diff --git a/bittensor/commands/stake.py b/bittensor/commands/stake.py index 132529a131..eff415d1a1 100644 --- a/bittensor/commands/stake.py +++ b/bittensor/commands/stake.py @@ -44,19 +44,25 @@ def get_netuid( - cli: "bittensor.cli", subtensor: "bittensor.subtensor" + cli: "bittensor.cli", subtensor: "bittensor.subtensor", prompt: bool = True ) -> Tuple[bool, int]: """Retrieve and validate the netuid from the user or configuration.""" console = Console() - if not cli.config.is_set("netuid"): - try: - cli.config.netuid = int(Prompt.ask("Enter netuid")) - except ValueError: - console.print( - "[red]Invalid input. Please enter a valid integer for netuid.[/red]" - ) - return False, -1 + if not cli.config.is_set("netuid") and prompt: + cli.config.netuid = Prompt.ask("Enter netuid") + try: + cli.config.netuid = int(cli.config.netuid) + except ValueError: + console.print( + "[red]Invalid input. Please enter a valid integer for netuid.[/red]" + ) + return False, -1 netuid = cli.config.netuid + if netuid < 0 or netuid > 65535: + console.print( + "[red]Invalid input. Please enter a valid integer for netuid in subnet range.[/red]" + ) + return False, -1 if not subtensor.subnet_exists(netuid=netuid): console.print( "[red]Network with netuid {} does not exist. Please try again.[/red]".format( @@ -1136,10 +1142,27 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): wallet = bittensor.wallet(config=cli.config) # check all - if not cli.config.is_set("all"): - exists, netuid = get_netuid(cli, subtensor) - if not exists: - return + if cli.config.is_set("all"): + cli.config.netuid = None + cli.config.all = True + elif cli.config.is_set("netuid"): + if cli.config.netuid == "all": + cli.config.all = True + else: + cli.config.netuid = int(cli.config.netuid) + exists, netuid = get_netuid(cli, subtensor) + if not exists: + return + else: + netuid_input = Prompt.ask("Enter netuid or 'all'", default="all") + if netuid_input == "all": + cli.config.netuid = None + cli.config.all = True + else: + cli.config.netuid = int(netuid_input) + exists, netuid = get_netuid(cli, subtensor, False) + if not exists: + return # get parent hotkey hotkey = get_hotkey(wallet, cli.config) @@ -1148,11 +1171,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): return try: - netuids = ( - subtensor.get_all_subnet_netuids() - if cli.config.is_set("all") - else [netuid] - ) + netuids = subtensor.get_all_subnet_netuids() if cli.config.all else [netuid] hotkey_stake = GetChildrenCommand.get_parent_stake_info( console, subtensor, hotkey ) @@ -1236,7 +1255,7 @@ def add_args(parser: argparse.ArgumentParser): parser = parser.add_parser( "get_children", help="""Get child hotkeys on subnet.""" ) - parser.add_argument("--netuid", dest="netuid", type=int, required=False) + parser.add_argument("--netuid", dest="netuid", type=str, required=False) parser.add_argument("--hotkey", dest="hotkey", type=str, required=False) parser.add_argument( "--all", @@ -1294,7 +1313,7 @@ def render_table( # Add columns to the table with specific styles table.add_column("Index", style="bold yellow", no_wrap=True, justify="center") - table.add_column("ChildHotkey", style="bold green") + table.add_column("Child Hotkey", style="bold green") table.add_column("Proportion", style="bold cyan", no_wrap=True, justify="right") table.add_column( "Childkey Take", style="bold blue", no_wrap=True, justify="right" diff --git a/tests/e2e_tests/subcommands/root/test_root_register_add_member_senate.py b/tests/e2e_tests/subcommands/root/test_root_register_add_member_senate.py index 7d45e5abcb..3626b48ce9 100644 --- a/tests/e2e_tests/subcommands/root/test_root_register_add_member_senate.py +++ b/tests/e2e_tests/subcommands/root/test_root_register_add_member_senate.py @@ -13,8 +13,25 @@ from ...utils import setup_wallet +def assert_sequence(lines, sequence): + sequence_ptr = 0 + for line in lines: + words_in_line = set(line.split()) + current_seq_set = sequence[sequence_ptr] + if current_seq_set.issubset(words_in_line): + sequence_ptr += 1 + if sequence_ptr == len(sequence): + break + + assert sequence_ptr == len( + sequence + ), f"Did not find sequence[{sequence_ptr}] = '{sequence[sequence_ptr]}' in output" + + def test_root_register_add_member_senate(local_chain, capsys): logging.info("Testing test_root_register_add_member_senate") + netuid = 1 + # Register root as Alice - the subnet owner alice_keypair, exec_command, wallet = setup_wallet("//Alice") exec_command(RegisterSubnetworkCommand, ["s", "create"]) @@ -26,7 +43,7 @@ def test_root_register_add_member_senate(local_chain, capsys): "s", "register", "--netuid", - "1", + str(netuid), ], ) @@ -45,9 +62,8 @@ def test_root_register_add_member_senate(local_chain, capsys): exec_command(SetTakeCommand, ["r", "set_take", "--take", "0.8"]) - captured = capsys.readouterr() - # Verify subnet 1 created successfully - assert local_chain.query("SubtensorModule", "NetworksAdded", [1]).serialize() + # Verify subnet created successfully + assert local_chain.query("SubtensorModule", "NetworksAdded", [netuid]).serialize() # Query local chain for senate members members = local_chain.query("SenateMembers", "Members").serialize() assert len(members) == 3, f"Expected 3 senate members, found {len(members)}" @@ -68,12 +84,16 @@ def test_root_register_add_member_senate(local_chain, capsys): captured = capsys.readouterr() # assert output is graph Titling "Senate" with names and addresses - assert "Senate" in captured.out - assert "NAME" in captured.out - assert "ADDRESS" in captured.out - assert "5CiPPseXPECbkjWCa6MnjNokrgYjMqmKndv2rSnekmSK2DjL" in captured.out - assert "5DAAnrj7VHTznn2AWBemMuyBwZWs6FNFjdyVXUeYum3PTXFy" in captured.out - assert "5HGjWAeFDfFCWPsjFQdVV2Msvz2XtMktvgocEZcCj68kUMaw" in captured.out + assert_sequence( + captured.out.split("\n"), + ( + {"Senate"}, + {"NAME", "ADDRESS"}, + {"5CiPPseXPECbkjWCa6MnjNokrgYjMqmKndv2rSnekmSK2DjL"}, + {"5DAAnrj7VHTznn2AWBemMuyBwZWs6FNFjdyVXUeYum3PTXFy"}, + {"5HGjWAeFDfFCWPsjFQdVV2Msvz2XtMktvgocEZcCj68kUMaw"}, + ), + ) exec_command( RootRegisterCommand, @@ -110,11 +130,14 @@ def test_root_register_add_member_senate(local_chain, capsys): captured = capsys.readouterr() # assert output is graph Titling "Senate" with names and addresses - - assert "Senate" in captured.out - assert "NAME" in captured.out - assert "ADDRESS" in captured.out - assert "5CiPPseXPECbkjWCa6MnjNokrgYjMqmKndv2rSnekmSK2DjL" in captured.out - assert "5DAAnrj7VHTznn2AWBemMuyBwZWs6FNFjdyVXUeYum3PTXFy" in captured.out - assert "5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY" in captured.out - assert "5HGjWAeFDfFCWPsjFQdVV2Msvz2XtMktvgocEZcCj68kUMaw" in captured.out + assert_sequence( + captured.out.split("\n"), + ( + {"Senate"}, + {"NAME", "ADDRESS"}, + {"5CiPPseXPECbkjWCa6MnjNokrgYjMqmKndv2rSnekmSK2DjL"}, + {"5DAAnrj7VHTznn2AWBemMuyBwZWs6FNFjdyVXUeYum3PTXFy"}, + {"5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY"}, + {"5HGjWAeFDfFCWPsjFQdVV2Msvz2XtMktvgocEZcCj68kUMaw"}, + ), + )