From 4e9369b288640a30026f5057b0dcd6865744f477 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 15:08:51 +0200 Subject: [PATCH 01/18] [WIP] Backporting the torch functionality --- bittensor/chain_data.py | 144 ++++++++++++++++++++++++++++++++----- bittensor/commands/root.py | 35 ++++++--- bittensor/dendrite.py | 22 ++++-- 3 files changed, 172 insertions(+), 29 deletions(-) diff --git a/bittensor/chain_data.py b/bittensor/chain_data.py index 99f435df59..2034332608 100644 --- a/bittensor/chain_data.py +++ b/bittensor/chain_data.py @@ -15,7 +15,7 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. import bittensor - +import os import json from enum import Enum from dataclasses import dataclass, asdict @@ -25,7 +25,7 @@ from scalecodec.type_registry import load_type_registry_preset from scalecodec.utils.ss58 import ss58_encode -from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT +from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT, maybe_get_torch from .utils.balance import Balance custom_rpc_type_registry = { @@ -264,15 +264,43 @@ def from_neuron_info(cls, neuron_info: dict) -> "AxonInfo": coldkey=neuron_info["coldkey"], ) - def to_parameter_dict(self) -> dict[str, Union[int, str]]: - r"""Returns a dict of the subnet info.""" + def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": + """Returns a torch tensor of the subnet info.""" + return maybe_get_torch().nn.ParameterDict(self.__dict__) + + def _to_parameter_dict_numpy(self) -> dict[str, Union[int, str]]: + """Returns a dict of the subnet info.""" return self.__dict__ + def to_parameter_dict( + self, + ) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]: + if os.environ.get("USE_TORCH"): + return self._to_parameter_dict_torch() + else: + return self._to_parameter_dict_numpy() + @classmethod - def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "AxonInfo": + def _from_parameter_dict_torch( + cls, parameter_dict: "torch.nn.ParameterDict" + ) -> "AxonInfo": + """Returns an axon_info object from a torch parameter_dict.""" + return cls(**dict(parameter_dict)) + + @classmethod + def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "AxonInfo": r"""Returns an axon_info object from a parameter_dict.""" return cls(**parameter_dict) + @classmethod + def from_parameter_dict( + cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] + ) -> "AxonInfo": + if os.environ.get("USE_TORCH"): + return cls._from_parameter_dict_torch(parameter_dict) + else: + return cls._from_parameter_dict_numpy(parameter_dict) + class ChainDataType(Enum): NeuronInfo = 1 @@ -980,15 +1008,41 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetInfo": owner_ss58=ss58_encode(decoded["owner"], bittensor.__ss58_format__), ) - def to_parameter_dict(self) -> dict[str, Any]: - r"""Returns a dict of the subnet info.""" + def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": + """Returns a torch tensor of the subnet info.""" + return maybe_get_torch().nn.ParameterDict(self.__dict__) + + def _to_parameter_dict_numpy(self) -> dict[str, Any]: + """Returns a dict of the subnet info.""" return self.__dict__ + def to_parameter_dict(self) -> Union[dict[str, Any], "torch.nn.ParameterDict"]: + if os.environ.get("USE_TORCH"): + return self._to_parameter_dict_torch() + else: + return self._to_parameter_dict_numpy() + + @classmethod + def _from_parameter_dict_torch( + cls, parameter_dict: "torch.nn.ParameterDict" + ) -> "SubnetInfo": + """Returns a SubnetInfo object from a torch parameter_dict.""" + return cls(**dict(parameter_dict)) + @classmethod - def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo": + def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo": r"""Returns a SubnetInfo object from a parameter_dict.""" return cls(**parameter_dict) + @classmethod + def from_parameter_dict( + cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] + ) -> "SubnetInfo": + if os.environ.get("USE_TORCH"): + return cls._from_parameter_dict_torch(parameter_dict) + else: + return cls._from_parameter_dict_numpy(parameter_dict) + @dataclass class SubnetHyperparameters: @@ -1074,15 +1128,45 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetHyperparameters": difficulty=decoded["difficulty"], ) - def to_parameter_dict(self) -> dict[str, Union[int, float, bool]]: - r"""Returns a dict of the subnet hyperparameters.""" + def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": + """Returns a torch tensor of the subnet hyperparameters.""" + return maybe_get_torch().nn.ParameterDict(self.__dict__) + + def _to_parameter_dict_numpy(self) -> dict[str, Union[int, float, bool]]: + """Returns a dict of the subnet hyperparameters.""" return self.__dict__ + def to_parameter_dict( + self, + ) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]: + if os.environ.get("USE_TORCH"): + return self._to_parameter_dict_torch() + else: + return self._to_parameter_dict_numpy() + + @classmethod + def _from_parameter_dict_torch( + cls, parameter_dict: "torch.nn.ParameterDict" + ) -> "SubnetHyperparameters": + """Returns a SubnetHyperparameters object from a torch parameter_dict.""" + return cls(**dict(parameter_dict)) + @classmethod - def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo": - r"""Returns a SubnetHyperparameters object from a parameter_dict.""" + def _from_parameter_dict_numpy( + cls, parameter_dict: dict[str, Any] + ) -> "SubnetHyperparameters": + """Returns a SubnetHyperparameters object from a parameter_dict.""" return cls(**parameter_dict) + @classmethod + def from_parameter_dict( + cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] + ) -> "SubnetHyperparameters": + if os.environ.get("USE_TORCH"): + return cls._from_parameter_dict_torch(parameter_dict) + else: + return cls._from_parameter_dict_numpy(parameter_dict) + @dataclass class IPInfo: @@ -1137,15 +1221,43 @@ def fix_decoded_values(cls, decoded: Dict) -> "IPInfo": protocol=decoded["ip_type_and_protocol"] & 0xF, ) - def to_parameter_dict(self) -> dict[str, Union[str, int]]: - r"""Returns a dict of the subnet ip info.""" + def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": + """Returns a torch tensor of the subnet info.""" + return maybe_get_torch().nn.ParameterDict(self.__dict__) + + def _to_parameter_dict_numpy(self) -> dict[str, Union[str, int]]: + """Returns a dict of the subnet ip info.""" return self.__dict__ + def to_parameter_dict( + self, + ) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]: + if os.environ.get("USE_TORCH"): + return self._to_parameter_dict_torch() + else: + return self._to_parameter_dict_numpy() + + @classmethod + def _from_parameter_dict_torch( + cls, parameter_dict: "torch.nn.ParameterDict" + ) -> "IPInfo": + """Returns a IPInfo object from a torch parameter_dict.""" + return cls(**dict(parameter_dict)) + @classmethod - def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "IPInfo": - r"""Returns a IPInfo object from a parameter_dict.""" + def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "IPInfo": + """Returns a IPInfo object from a parameter_dict.""" return cls(**parameter_dict) + @classmethod + def from_parameter_dict( + cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] + ) -> "IPInfo": + if os.environ.get("USE_TORCH"): + return cls._from_parameter_dict_torch(parameter_dict) + else: + return cls._from_parameter_dict_numpy(parameter_dict) + # Senate / Proposal data diff --git a/bittensor/commands/root.py b/bittensor/commands/root.py index 912390cafc..f5785f3ebb 100644 --- a/bittensor/commands/root.py +++ b/bittensor/commands/root.py @@ -15,8 +15,8 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import os import re -import numpy as np import typing import argparse import numpy as np @@ -24,7 +24,7 @@ from typing import List, Optional, Dict from rich.prompt import Prompt from rich.table import Table -from .utils import get_delegates_details, DelegatesDetails +from .utils import get_delegates_details, DelegatesDetails, maybe_get_torch from . import defaults @@ -301,7 +301,11 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): f"Boosting weight for netuid {cli.config.netuid} from {prev_weight} -> {new_weight}" ) my_weights[cli.config.netuid] = new_weight - all_netuids = np.arange(len(my_weights)) + all_netuids = ( + maybe_get_torch().tensor(list(range(len(my_weights)))) + if os.environ.get("USE_TORCH") + else np.arange(len(my_weights)) + ) bittensor.__console__.print("Setting root weights...") subtensor.root_set_weights( @@ -419,7 +423,11 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): my_weights = root.weights[my_uid] my_weights[cli.config.netuid] -= cli.config.amount my_weights[my_weights < 0] = 0 # Ensure weights don't go negative - all_netuids = np.arange(len(my_weights)) + all_netuids = ( + maybe_get_torch().tensor(list(range(len(my_weights)))) + if os.environ.get("USE_TORCH") + else np.arange(len(my_weights)) + ) subtensor.root_set_weights( wallet=wallet, @@ -520,12 +528,21 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): cli.config.weights = Prompt.ask(f"Enter weights (e.g. {example})") # Parse from string - netuids = np.array( - list(map(int, re.split(r"[ ,]+", cli.config.netuids))), dtype=np.int64 + matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids))) + netuids = ( + maybe_get_torch().tensor(matched_netuids, dtype=maybe_get_torch().long) + if os.environ.get("USE_TORCH") + else np.array(matched_netuids, dtype=np.int64) ) - weights = np.array( - list(map(float, re.split(r"[ ,]+", cli.config.weights))), - dtype=np.float32, + + matched_weights = list(map(float, re.split(r"[ ,]+", cli.config.weights))) + weights = ( + maybe_get_torch().tensor(matched_weights, dtype=maybe_get_torch().float32) + if os.environ.get("USE_TORCH") + else np.array( + matched_weights, + dtype=np.float32, + ) ) # Run the set weights operation. diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 9a9202ab31..684a7ec92e 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -20,14 +20,16 @@ from __future__ import annotations import asyncio +import os import uuid import time import aiohttp import bittensor from typing import Union, Optional, List, Union, AsyncGenerator, Any +from .utils import maybe_get_torch -class dendrite: +class DendriteMixin: """ The Dendrite class represents the abstracted implementation of a network client module. @@ -121,9 +123,6 @@ def __init__( self._session: Optional[aiohttp.ClientSession] = None - async def __call__(self, *args, **kwargs): - return await self.forward(*args, **kwargs) - @property async def session(self) -> aiohttp.ClientSession: """ @@ -808,3 +807,18 @@ def __del__(self): del dendrite # This will implicitly invoke the __del__ method and close the session. """ self.close_session() + + +if os.environ.get("USE_TORCH"): + class dendrite(maybe_get_torch().nn.module, DendriteMixin): + def __init__(self): + maybe_get_torch().nn.module.__init__(self) + DendriteMixin.__init__(self) +else: + class dendrite(DendriteMixin): + def __init__(self): + DendriteMixin.__init__(self) + + async def __call__(self, *args, **kwargs): + return await self.forward(*args, **kwargs) + From 346c1894a29d4b76196e2e2c663f4b3fea9213c8 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 15:21:35 +0200 Subject: [PATCH 02/18] [WIP] Black format. --- bittensor/dendrite.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 684a7ec92e..2040458ccd 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -810,15 +810,17 @@ def __del__(self): if os.environ.get("USE_TORCH"): + class dendrite(maybe_get_torch().nn.module, DendriteMixin): def __init__(self): maybe_get_torch().nn.module.__init__(self) DendriteMixin.__init__(self) + else: + class dendrite(DendriteMixin): def __init__(self): DendriteMixin.__init__(self) async def __call__(self, *args, **kwargs): return await self.forward(*args, **kwargs) - From d1311f65b3116dab35c6d3ce2769bf2178544f90 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 18:25:30 +0200 Subject: [PATCH 03/18] [WIP] check-in --- bittensor/chain_data.py | 99 +++++++++++++++------------- bittensor/dendrite.py | 6 +- bittensor/extrinsics/registration.py | 8 +-- bittensor/metagraph.py | 29 ++++---- bittensor/utils/__init__.py | 51 +++++++++----- bittensor/utils/registration.py | 21 +++++- 6 files changed, 123 insertions(+), 91 deletions(-) diff --git a/bittensor/chain_data.py b/bittensor/chain_data.py index 2034332608..9b81d87803 100644 --- a/bittensor/chain_data.py +++ b/bittensor/chain_data.py @@ -25,7 +25,7 @@ from scalecodec.type_registry import load_type_registry_preset from scalecodec.utils.ss58 import ss58_encode -from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT, maybe_get_torch +from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT, torch from .utils.balance import Balance custom_rpc_type_registry = { @@ -264,42 +264,43 @@ def from_neuron_info(cls, neuron_info: dict) -> "AxonInfo": coldkey=neuron_info["coldkey"], ) - def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": - """Returns a torch tensor of the subnet info.""" - return maybe_get_torch().nn.ParameterDict(self.__dict__) - - def _to_parameter_dict_numpy(self) -> dict[str, Union[int, str]]: - """Returns a dict of the subnet info.""" - return self.__dict__ + def _to_parameter_dict( + self, return_type: str + ) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]: + if return_type == "torch": + return torch.nn.ParameterDict(self.__dict__) + else: + return self.__dict__ def to_parameter_dict( self, ) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]: + """Returns a torch tensor or dict of the subnet info, depending on the USE_TORCH flag set""" if os.environ.get("USE_TORCH"): - return self._to_parameter_dict_torch() + return self._to_parameter_dict("torch") else: - return self._to_parameter_dict_numpy() + return self._to_parameter_dict("numpy") @classmethod - def _from_parameter_dict_torch( - cls, parameter_dict: "torch.nn.ParameterDict" + def _from_parameter_dict( + cls, + parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"], + return_type: str, ) -> "AxonInfo": - """Returns an axon_info object from a torch parameter_dict.""" - return cls(**dict(parameter_dict)) - - @classmethod - def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "AxonInfo": - r"""Returns an axon_info object from a parameter_dict.""" - return cls(**parameter_dict) + if return_type == "torch": + return cls(**dict(parameter_dict)) + else: + return cls(**parameter_dict) @classmethod def from_parameter_dict( cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] ) -> "AxonInfo": + """Returns an axon_info object from a torch parameter_dict or a parameter dict.""" if os.environ.get("USE_TORCH"): - return cls._from_parameter_dict_torch(parameter_dict) + return cls._from_parameter_dict(parameter_dict, "torch") else: - return cls._from_parameter_dict_numpy(parameter_dict) + return cls._from_parameter_dict(parameter_dict, "numpy") class ChainDataType(Enum): @@ -1008,19 +1009,20 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetInfo": owner_ss58=ss58_encode(decoded["owner"], bittensor.__ss58_format__), ) - def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": - """Returns a torch tensor of the subnet info.""" - return maybe_get_torch().nn.ParameterDict(self.__dict__) - - def _to_parameter_dict_numpy(self) -> dict[str, Any]: - """Returns a dict of the subnet info.""" - return self.__dict__ + def _to_parameter_dict( + self, return_type: str + ) -> Union[dict[str, Any], "torch.nn.ParameterDict"]: + if return_type == "torch": + return torch.nn.ParameterDict(self.__dict__) + else: + return self.__dict__ def to_parameter_dict(self) -> Union[dict[str, Any], "torch.nn.ParameterDict"]: + """Returns a torch tensor or dict of the subnet info.""" if os.environ.get("USE_TORCH"): - return self._to_parameter_dict_torch() + return self._to_parameter_dict("torch") else: - return self._to_parameter_dict_numpy() + return self._to_parameter_dict("numpy") @classmethod def _from_parameter_dict_torch( @@ -1128,21 +1130,22 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetHyperparameters": difficulty=decoded["difficulty"], ) - def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": - """Returns a torch tensor of the subnet hyperparameters.""" - return maybe_get_torch().nn.ParameterDict(self.__dict__) - - def _to_parameter_dict_numpy(self) -> dict[str, Union[int, float, bool]]: - """Returns a dict of the subnet hyperparameters.""" - return self.__dict__ + def _to_parameter_dict_torch( + self, return_type: str + ) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]: + if return_type == "torch": + return torch.nn.ParameterDict(self.__dict__) + else: + return self.__dict__ def to_parameter_dict( self, ) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]: + """Returns a torch tensor or dict of the subnet hyperparameters.""" if os.environ.get("USE_TORCH"): - return self._to_parameter_dict_torch() + return self._to_parameter_dict_torch("torch") else: - return self._to_parameter_dict_numpy() + return self._to_parameter_dict_torch("numpy") @classmethod def _from_parameter_dict_torch( @@ -1221,21 +1224,23 @@ def fix_decoded_values(cls, decoded: Dict) -> "IPInfo": protocol=decoded["ip_type_and_protocol"] & 0xF, ) - def _to_parameter_dict_torch(self) -> "torch.nn.ParameterDict": + def _to_parameter_dict( + self, return_type: str + ) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]: """Returns a torch tensor of the subnet info.""" - return maybe_get_torch().nn.ParameterDict(self.__dict__) - - def _to_parameter_dict_numpy(self) -> dict[str, Union[str, int]]: - """Returns a dict of the subnet ip info.""" - return self.__dict__ + if return_type == "torch": + return torch.nn.ParameterDict(self.__dict__) + else: + return self.__dict__ def to_parameter_dict( self, ) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]: + """Returns a torch tensor or dict of the subnet IP info.""" if os.environ.get("USE_TORCH"): - return self._to_parameter_dict_torch() + return self._to_parameter_dict("torch") else: - return self._to_parameter_dict_numpy() + return self._to_parameter_dict("numpy") @classmethod def _from_parameter_dict_torch( diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 2040458ccd..51cd38f472 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -26,7 +26,7 @@ import aiohttp import bittensor from typing import Union, Optional, List, Union, AsyncGenerator, Any -from .utils import maybe_get_torch +from .utils import torch class DendriteMixin: @@ -811,9 +811,9 @@ def __del__(self): if os.environ.get("USE_TORCH"): - class dendrite(maybe_get_torch().nn.module, DendriteMixin): + class dendrite(torch.nn.Module, DendriteMixin): def __init__(self): - maybe_get_torch().nn.module.__init__(self) + torch.nn.Module.__init__(self) DendriteMixin.__init__(self) else: diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index 367d987b5e..50dc33a9f5 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -21,7 +21,7 @@ import time from rich.prompt import Confirm from typing import List, Union, Optional, Tuple -from bittensor.utils.registration import POWSolution, create_pow, maybe_get_torch +from bittensor.utils.registration import POWSolution, create_pow, torch def register_extrinsic( @@ -101,8 +101,7 @@ def register_extrinsic( ): return False - torch = maybe_get_torch() - if torch is None: + if not torch: return False # Attempt rolling registration. @@ -382,8 +381,7 @@ def run_faucet_extrinsic( ): return False - torch = maybe_get_torch() - if torch is None: + if not torch: return False # Unlock coldkey diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index e4a7ff21e3..abb5462348 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -27,7 +27,7 @@ from typing import List, Optional from bittensor.chain_data import AxonInfo -from bittensor.utils.registration import maybe_get_torch +from bittensor.utils.registration import torch METAGRAPH_STATE_DICT_NDARRAY_KEYS = [ "version", @@ -873,22 +873,17 @@ def load_from_path(self, dir_path: str) -> "metagraph": bittensor.__console__.print( "Unable to load file. Attempting to restore metagraph using torch." ) - if not (torch := maybe_get_torch()): - raise ImportError - else: - bittensor.__console__.print( - ":warning:[yellow]Warning:[/yellow] This functionality exists to load " - "metagraph state from legacy saves, but will not be supported in the future." - ) - try: - state_dict = torch.load(graph_filename) - for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS: - state_dict[key] = state_dict[key].detach().numpy() - except RuntimeError: - bittensor.__console__.print( - "Unable to load file. It may be corrupted." - ) - raise + bittensor.__console__.print( + ":warning:[yellow]Warning:[/yellow] This functionality exists to load " + "metagraph state from legacy saves, but will not be supported in the future." + ) + try: + state_dict = torch.load(graph_filename) + for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS: + state_dict[key] = state_dict[key].detach().numpy() + except RuntimeError: + bittensor.__console__.print("Unable to load file. It may be corrupted.") + raise self.n = state_dict["n"] self.block = state_dict["block"] diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 0a8546386a..5450a6eed3 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright © 2022 Opentensor Foundation # Copyright © 2023 Opentensor Technologies Inc +import os # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -20,14 +21,13 @@ import bittensor import hashlib -import torch import requests import scalecodec import numpy as np from .wallet_utils import * # noqa F401 from .version import version_checking, check_version, VersionCheckError -from .registration import maybe_get_torch +from .registration import torch RAOPERTAO = 1e9 U16_MAX = 65535 @@ -40,22 +40,9 @@ def ss58_to_vec_u8(ss58_address: str) -> List[int]: return encoded_address -def unbiased_topk( +def _unbiased_topk( values: np.ndarray, k: int, dim=0, sorted=True, largest=True, axis=0 ) -> Tuple[np.ndarray, np.ndarray]: - r"""Selects topk as in torch.topk but does not bias lower indices when values are equal. - Args: - values: (np.ndarray) - Values to index into. - k: (int): - Number to take. - - Return: - topk: (np.ndarray): - topk k values. - indices: (np.ndarray) - indices of the topk values. - """ if dim != 0 and axis == 0: # Ensures a seamless transition for calls made to this function that specified args by keyword axis = dim @@ -71,6 +58,38 @@ def unbiased_topk( return topk, permutation[indices] +def unbiased_topk( + values: "torch.Tensor", + k: int, + dim: int = 0, + sorted: bool = True, + largest: bool = True, + axis: int = 0, +) -> Union[Tuple[np.ndarray, np.ndarray], Tuple["torch.Tensor", "torch.LongTensor"]]: + r"""Selects topk as in torch.topk but does not bias lower indices when values are equal. + Args: + values: (np.ndarray) if using numpy, (torch.Tensor) if using torch: + Values to index into. + k: (int): + Number to take. + + Return: + topk: (np.ndarray) if using numpy, (torch.Tensor) if using torch: + topk k values. + indices: (np.ndarray) if using numpy, (torch.LongTensor) if using torch: + indices of the topk values. + """ + if os.getenv("USE_TORCH"): + permutation = torch.randperm(values.shape[dim]) + permuted_values = values[permutation] + topk, indices = torch.topk( + permuted_values, k, dim=dim, sorted=sorted, largest=largest + ) + return topk, permutation[indices] + else: + return _unbiased_topk(values, k, dim, sorted, largest, axis) + + def strtobool_with_default( default: bool, ) -> Callable[[str], Union[bool, Literal["==SUPRESS=="]]]: diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index 3fe4b03a89..aac4c7f286 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -25,12 +25,27 @@ torch = None -def maybe_get_torch(): - if torch is None: +class Torch: + @staticmethod + def _error(): bittensor.logging.warning( "This command requires torch. Please install torch package." ) - return torch + raise ImportError + + def __bool__(self): + self._error() + return False + + def __getattr__(self, *_): + self._error() + + def __call__(self, *_): + self._error() + + +if not torch or not os.getenv("USE_TORCH"): + torch = Torch() class CUDAException(Exception): From 1f2989ebac90ad156a2aa7e1817ee56250077abb Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 20:15:09 +0200 Subject: [PATCH 04/18] [WIP] Check-in --- bittensor/commands/root.py | 10 ++-- bittensor/commands/utils.py | 8 ++-- bittensor/extrinsics/root.py | 40 +++++++++++----- bittensor/subtensor.py | 18 +++---- bittensor/tensor.py | 93 ++++++++++++++++++++++++------------ bittensor/utils/__init__.py | 57 +++++++++++++--------- 6 files changed, 144 insertions(+), 82 deletions(-) diff --git a/bittensor/commands/root.py b/bittensor/commands/root.py index f5785f3ebb..003d22fe4f 100644 --- a/bittensor/commands/root.py +++ b/bittensor/commands/root.py @@ -24,7 +24,7 @@ from typing import List, Optional, Dict from rich.prompt import Prompt from rich.table import Table -from .utils import get_delegates_details, DelegatesDetails, maybe_get_torch +from .utils import get_delegates_details, DelegatesDetails, torch from . import defaults @@ -302,7 +302,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): ) my_weights[cli.config.netuid] = new_weight all_netuids = ( - maybe_get_torch().tensor(list(range(len(my_weights)))) + torch.tensor(list(range(len(my_weights)))) if os.environ.get("USE_TORCH") else np.arange(len(my_weights)) ) @@ -424,7 +424,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): my_weights[cli.config.netuid] -= cli.config.amount my_weights[my_weights < 0] = 0 # Ensure weights don't go negative all_netuids = ( - maybe_get_torch().tensor(list(range(len(my_weights)))) + torch.tensor(list(range(len(my_weights)))) if os.environ.get("USE_TORCH") else np.arange(len(my_weights)) ) @@ -530,14 +530,14 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): # Parse from string matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids))) netuids = ( - maybe_get_torch().tensor(matched_netuids, dtype=maybe_get_torch().long) + torch.tensor(matched_netuids, dtype=torch.long) if os.environ.get("USE_TORCH") else np.array(matched_netuids, dtype=np.int64) ) matched_weights = list(map(float, re.split(r"[ ,]+", cli.config.weights))) weights = ( - maybe_get_torch().tensor(matched_weights, dtype=maybe_get_torch().float32) + torch.tensor(matched_weights, dtype=torch.float32) if os.environ.get("USE_TORCH") else np.array( matched_weights, diff --git a/bittensor/commands/utils.py b/bittensor/commands/utils.py index 3fe4db52ef..4ea8fa3dd1 100644 --- a/bittensor/commands/utils.py +++ b/bittensor/commands/utils.py @@ -19,7 +19,7 @@ import os import bittensor import requests -from bittensor.utils.registration import maybe_get_torch +from bittensor.utils.registration import torch from typing import List, Dict, Any, Optional from rich.prompt import Confirm, PromptBase from dataclasses import dataclass @@ -78,11 +78,9 @@ def check_netuid_set( def check_for_cuda_reg_config(config: "bittensor.config") -> None: """Checks, when CUDA is available, if the user would like to register with their CUDA device.""" - - torch = maybe_get_torch() - if torch is not None and torch.cuda.is_available(): + if torch and torch.cuda.is_available(): if not config.no_prompt: - if config.pow_register.cuda.get("use_cuda") == None: # flag not set + if config.pow_register.cuda.get("use_cuda") is None: # flag not set # Ask about cuda registration only if a CUDA device is available. cuda = Confirm.ask("Detected CUDA device, use CUDA for registration?\n") config.pow_register.cuda.use_cuda = cuda diff --git a/bittensor/extrinsics/root.py b/bittensor/extrinsics/root.py index 826bdf7973..9625a37e31 100644 --- a/bittensor/extrinsics/root.py +++ b/bittensor/extrinsics/root.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright © 2021 Yuma Rao # Copyright © 2023 Opentensor Foundation +import os # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -23,9 +24,10 @@ import numpy as np from numpy.typing import NDArray from rich.prompt import Confirm -from typing import Union +from typing import Union, List import bittensor.utils.weight_utils as weight_utils from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME +from bittensor.utils import torch logger = logging.getLogger(BITTENSOR_LOGGER_NAME) @@ -102,8 +104,8 @@ def root_register_extrinsic( def set_root_weights_extrinsic( subtensor: "bittensor.subtensor", wallet: "bittensor.wallet", - netuids: Union[NDArray[np.int64], list], - weights: Union[NDArray[np.float32], list], + netuids: Union[NDArray[np.int64], "torch.LongTensor", List[int]], + weights: Union[NDArray[np.float32], "torch.FloatTensor", List[float]], version_key: int = 0, wait_for_inclusion: bool = False, wait_for_finalization: bool = False, @@ -114,9 +116,9 @@ def set_root_weights_extrinsic( Args: wallet (bittensor.wallet): Bittensor wallet object. - netuids (List[int]): + netuids (Union[NDArray[np.int64], torch.LongTensor, List[int]]): The ``netuid`` of the subnet to set weights for. - weights (Union[NDArray[np.float32], list]): + weights (Union[NDArray[np.float32], torch.FloatTensor, list]): Weights to set. These must be ``float`` s and must correspond to the passed ``netuid`` s. version_key (int): The version key of the validator. @@ -132,22 +134,36 @@ def set_root_weights_extrinsic( """ # First convert types. if isinstance(netuids, list): - netuids = np.array(netuids, dtype=np.int64) + netuids = ( + torch.tensor(netuids, dtype=torch.int64) + if os.getenv("USE_TORCH") + else np.array(netuids, dtype=np.int64) + ) if isinstance(weights, list): - weights = np.array(weights, dtype=np.float32) + weights = ( + torch.tensor(weights, dtype=torch.float32) + if os.getenv("USE_TORCH") + else np.array(weights, dtype=np.float32) + ) # Get weight restrictions. min_allowed_weights = subtensor.min_allowed_weights(netuid=0) max_weight_limit = subtensor.max_weight_limit(netuid=0) # Get non zero values. - non_zero_weight_idx = np.argwhere(weights > 0).squeeze(axis=1) - non_zero_weight_uids = netuids[non_zero_weight_idx] + non_zero_weight_idx = ( + torch.argwhere(weights > 0).squeeze(dim=1) + if os.getenv("USE_TORCH") + else np.argwhere(weights > 0).squeeze(axis=1) + ) non_zero_weights = weights[non_zero_weight_idx] - if non_zero_weights.size < min_allowed_weights: + non_zero_weights_size = ( + non_zero_weights.numel() if os.getenv("USE_TORCH") else non_zero_weights.size + ) + if non_zero_weights_size < min_allowed_weights: raise ValueError( "The minimum number of weights required to set weights is {}, got {}".format( - min_allowed_weights, non_zero_weights.size + min_allowed_weights, non_zero_weights_size ) ) @@ -192,7 +208,7 @@ def set_root_weights_extrinsic( if not wait_for_finalization and not wait_for_inclusion: return True - if success == True: + if success is True: bittensor.__console__.print( ":white_heavy_check_mark: [green]Finalized[/green]" ) diff --git a/bittensor/subtensor.py b/bittensor/subtensor.py index be40a818c6..7df6d7223f 100644 --- a/bittensor/subtensor.py +++ b/bittensor/subtensor.py @@ -35,6 +35,8 @@ from scalecodec.type_registry import load_type_registry_preset from scalecodec.types import GenericCall +from bittensor.utils import torch + # Local imports. from .chain_data import ( NeuronInfo, @@ -671,8 +673,8 @@ def set_weights( self, wallet: "bittensor.wallet", netuid: int, - uids: Union[NDArray[np.int64], list], - weights: Union[NDArray[np.float32], list], + uids: Union[NDArray[np.int64], "torch.LongTensor", list], + weights: Union[NDArray[np.float32], "torch.FloatTensor", list], version_key: int = bittensor.__version_as_int__, uid: Optional[int] = None, wait_for_inclusion: bool = False, @@ -689,8 +691,8 @@ def set_weights( wallet (bittensor.wallet): The wallet associated with the neuron setting the weights. netuid (int): The unique identifier of the subnet. uid (int): Unique identifier for the caller on the subnet specified by `netuid`. - uids (Union[NDArray[np.int64], list]): The list of neuron UIDs that the weights are being set for. - weights (Union[NDArray[np.float32], list]): The corresponding weights to be set for each UID. + uids (Union[NDArray[np.int64], torch.LongTensor, list]): The list of neuron UIDs that the weights are being set for. + weights (Union[NDArray[np.float32], torch.FloatTensor, list]): The corresponding weights to be set for each UID. version_key (int, optional): Version key for compatibility with the network. wait_for_inclusion (bool, optional): Waits for the transaction to be included in a block. wait_for_finalization (bool, optional): Waits for the transaction to be finalized on the blockchain. @@ -2151,8 +2153,8 @@ def make_substrate_call_with_retry(): def root_set_weights( self, wallet: "bittensor.wallet", - netuids: Union[NDArray[np.int64], list], - weights: Union[NDArray[np.float32], list], + netuids: Union[NDArray[np.int64], "torch.LongTensor", list], + weights: Union[NDArray[np.float32], "torch.FloatTensor", list], version_key: int = 0, wait_for_inclusion: bool = False, wait_for_finalization: bool = False, @@ -2164,8 +2166,8 @@ def root_set_weights( Args: wallet (bittensor.wallet): The wallet associated with the neuron setting the weights. - netuids (Union[NDArray[np.int64], list]): The list of neuron UIDs for which weights are being set. - weights (Union[NDArray[np.float32], list]): The corresponding weights to be set for each UID. + netuids (Union[NDArray[np.int64], torch.LongTensor, list]): The list of neuron UIDs for which weights are being set. + weights (Union[NDArray[np.float32], torch.FloatTensor, list]): The corresponding weights to be set for each UID. version_key (int, optional): Version key for compatibility with the network. wait_for_inclusion (bool, optional): Waits for the transaction to be included in a block. wait_for_finalization (bool, optional): Waits for the transaction to be finalized on the blockchain. diff --git a/bittensor/tensor.py b/bittensor/tensor.py index 5949554041..c95c43c6d6 100644 --- a/bittensor/tensor.py +++ b/bittensor/tensor.py @@ -1,6 +1,7 @@ # The MIT License (MIT) # Copyright © 2021 Yuma Rao # Copyright © 2022 Opentensor Foundation +import os # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -22,6 +23,7 @@ import pydantic import msgpack_numpy from typing import Optional, Union, List +from bittensor.utils import torch NUMPY_DTYPES = { "float16": np.float16, @@ -35,16 +37,31 @@ "bool": bool, } - -def cast_dtype(raw: Union[None, np.dtype, str]) -> str: +if os.getenv("USE_TORCH"): + TORCH_DTYPES = { + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "torch.float64": torch.float64, + "torch.uint8": torch.uint8, + "torch.int16": torch.int16, + "torch.int8": torch.int8, + "torch.int32": torch.int32, + "torch.int64": torch.int64, + "torch.bool": torch.bool, + } + + +def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str: """ - Casts the raw value to a string representing the `numpy data type `_. + Casts the raw value to a string representing the + `numpy data type `_, or the + `torch data type `_ if using torch. Args: - raw (Union[None, numpy.dtype, str]): The raw value to cast. + raw (Union[None, numpy.dtype, torch.dtype, str]): The raw value to cast. Returns: - str: The string representing the numpy data type. + str: The string representing the numpy/torch data type. Raises: Exception: If the raw value is of an invalid type. @@ -53,14 +70,19 @@ def cast_dtype(raw: Union[None, np.dtype, str]) -> str: return None if isinstance(raw, np.dtype): return NUMPY_DTYPES[raw] + elif os.getenv("USE_TORCH"): + if isinstance(raw, torch.dtype): + return TORCH_DTYPES[raw] elif isinstance(raw, str): - assert ( - raw in NUMPY_DTYPES - ), f"{str} not a valid numpy type in dict {NUMPY_DTYPES}" - return raw + if os.getenv("USE_TORCH"): + assert raw in TORCH_DTYPES, f"{raw} not a valid torch type in dict {TORCH_DTYPES}" + return raw + else: + assert raw in NUMPY_DTYPES, f"{raw} not a valid numpy type in dict {NUMPY_DTYPES}" + return raw else: raise Exception( - f"{raw} of type {type(raw)} does not have a valid type in Union[None, numpy.dtype, str]" + f"{raw} of type {type(raw)} does not have a valid type in Union[None, numpy.dtype, torch.dtype, str]" ) @@ -96,11 +118,9 @@ def cast_shape(raw: Union[None, List[int], str]) -> str: class tensor: - def __new__(cls, tensor: Union[list, np.ndarray, np.ndarray]): - if isinstance(tensor, list): - tensor = np.array(tensor) - elif isinstance(tensor, np.ndarray): - tensor = np.array(tensor) + def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]): + if isinstance(tensor, list) or isinstance(tensor, np.ndarray): + tensor = torch.tensor(tensor) if os.getenv("USE_TORCH") else np.array(tensor) return Tensor.serialize(tensor=tensor) @@ -117,21 +137,21 @@ class Tensor(pydantic.BaseModel): class Config: validate_assignment = True - def tensor(self) -> np.ndarray: + def tensor(self) -> Union[np.ndarray, "torch.Tensor"]: return self.deserialize() def tolist(self) -> List[object]: return self.deserialize().tolist() def numpy(self) -> "numpy.ndarray": - return self.deserialize() + return self.deserialize().detach().numpy() if os.getenv("USE_TORCH") else self.deserialize() - def deserialize(self) -> "np.ndarray": + def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]: """ Deserializes the Tensor object. Returns: - np.array: The deserialized tensor object. + np.array or torch.Tensor: The deserialized tensor object. Raises: Exception: If the deserialization process encounters an error. @@ -141,19 +161,25 @@ def deserialize(self) -> "np.ndarray": numpy_object = msgpack.unpackb( buffer_bytes, object_hook=msgpack_numpy.decode ).copy() - numpy = numpy_object - # Reshape does not work for (0) or [0] - if not (len(shape) == 1 and shape[0] == 0): - numpy = numpy.reshape(shape) - return numpy.astype(NUMPY_DTYPES[self.dtype]) + if os.getenv("USE_TORCH"): + torch_object = torch.as_tensor(numpy_object) + # Reshape does not work for (0) or [0] + if not (len(shape) == 1 and shape[0] == 0): + torch_object = torch_object.reshape(shape) + return torch_object.type(TORCH_DTYPES[self.dtype]) + else: + # Reshape does not work for (0) or [0] + if not (len(shape) == 1 and shape[0] == 0): + numpy_object = numpy_object.reshape(shape) + return numpy_object.astype(NUMPY_DTYPES[self.dtype]) @staticmethod - def serialize(tensor: "np.ndarray") -> "Tensor": + def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor": """ Serializes the given tensor. Args: - tensor (np.array): The tensor to serialize. + tensor (np.array or torch.Tensor): The tensor to serialize. Returns: Tensor: The serialized tensor. @@ -165,9 +191,15 @@ def serialize(tensor: "np.ndarray") -> "Tensor": shape = list(tensor.shape) if len(shape) == 0: shape = [0] - data_buffer = base64.b64encode( - msgpack.packb(tensor, default=msgpack_numpy.encode) - ).decode("utf-8") + if os.getenv("USE_TORCH"): + torch_numpy = tensor.cpu().detach().numpy().copy() + data_buffer = base64.b64encode( + msgpack.packb(torch_numpy, default=msgpack_numpy.encode) + ).decode("utf-8") + else: + data_buffer = base64.b64encode( + msgpack.packb(tensor, default=msgpack_numpy.encode) + ).decode("utf-8") return Tensor(buffer=data_buffer, shape=shape, dtype=dtype) buffer: Optional[str] = pydantic.Field( @@ -180,7 +212,8 @@ def serialize(tensor: "np.ndarray") -> "Tensor": dtype: str = pydantic.Field( title="dtype", - description="Tensor data type. This field specifies the data type of the tensor, such as numpy.float32 or numpy.int64.", + description="Tensor data type. " + "This field specifies the data type of the tensor, such as numpy.float32 or torch.int64.", examples="np.float32", allow_mutation=False, repr=True, diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 5450a6eed3..09cf9828bc 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -41,25 +41,39 @@ def ss58_to_vec_u8(ss58_address: str) -> List[int]: def _unbiased_topk( - values: np.ndarray, k: int, dim=0, sorted=True, largest=True, axis=0 -) -> Tuple[np.ndarray, np.ndarray]: - if dim != 0 and axis == 0: - # Ensures a seamless transition for calls made to this function that specified args by keyword - axis = dim - - permutation = np.random.permutation(values.shape[axis]) - permuted_values = np.take(values, permutation, axis=axis) - indices = np.argpartition(permuted_values, -k, axis=axis)[-k:] - if not sorted: - indices = np.sort(indices, axis=axis) - if not largest: - indices = indices[::-1] - topk = np.take(permuted_values, indices, axis=axis) - return topk, permutation[indices] + values: Union[np.ndarray, "torch.Tensor"], + k: int, + dim=0, + sorted=True, + largest=True, + axis=0, + return_type: str = "numpy", +) -> Union[Tuple[np.ndarray, np.ndarray], Tuple["torch.Tensor", "torch.LongTensor"]]: + if return_type == "torch": + permutation = torch.randperm(values.shape[dim]) + permuted_values = values[permutation] + topk, indices = torch.topk( + permuted_values, k, dim=dim, sorted=sorted, largest=largest + ) + return topk, permutation[indices] + else: + if dim != 0 and axis == 0: + # Ensures a seamless transition for calls made to this function that specified args by keyword + axis = dim + + permutation = np.random.permutation(values.shape[axis]) + permuted_values = np.take(values, permutation, axis=axis) + indices = np.argpartition(permuted_values, -k, axis=axis)[-k:] + if not sorted: + indices = np.sort(indices, axis=axis) + if not largest: + indices = indices[::-1] + topk = np.take(permuted_values, indices, axis=axis) + return topk, permutation[indices] def unbiased_topk( - values: "torch.Tensor", + values: Union[np.ndarray, "torch.Tensor"], k: int, dim: int = 0, sorted: bool = True, @@ -80,14 +94,13 @@ def unbiased_topk( indices of the topk values. """ if os.getenv("USE_TORCH"): - permutation = torch.randperm(values.shape[dim]) - permuted_values = values[permutation] - topk, indices = torch.topk( - permuted_values, k, dim=dim, sorted=sorted, largest=largest + return _unbiased_topk( + values, k, dim, sorted, largest, axis, return_type="torch" ) - return topk, permutation[indices] else: - return _unbiased_topk(values, k, dim, sorted, largest, axis) + return _unbiased_topk( + values, k, dim, sorted, largest, axis, return_type="numpy" + ) def strtobool_with_default( From 81f2a42bb5426dc5e60104f8002d73da7ad10255 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 21:02:00 +0200 Subject: [PATCH 05/18] [WIP] Check-in --- bittensor/dendrite.py | 10 +-- bittensor/tensor.py | 20 +++-- bittensor/utils/registration.py | 1 - bittensor/utils/weight_utils.py | 126 ++++++++++++++++++++++++-------- 4 files changed, 115 insertions(+), 42 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 51cd38f472..9d192d9fd9 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -106,7 +106,7 @@ def __init__( The user's wallet or keypair used for signing messages. Defaults to ``None``, in which case a new :func:`bittensor.wallet().hotkey` is generated and used. """ # Initialize the parent class - super(dendrite, self).__init__() + super(DendriteMixin, self).__init__() # Unique identifier for the instance self.uuid = str(uuid.uuid1()) @@ -812,15 +812,15 @@ def __del__(self): if os.environ.get("USE_TORCH"): class dendrite(torch.nn.Module, DendriteMixin): - def __init__(self): + def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): torch.nn.Module.__init__(self) - DendriteMixin.__init__(self) + DendriteMixin.__init__(self, wallet) else: class dendrite(DendriteMixin): - def __init__(self): - DendriteMixin.__init__(self) + def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): + DendriteMixin.__init__(self, wallet) async def __call__(self, *args, **kwargs): return await self.forward(*args, **kwargs) diff --git a/bittensor/tensor.py b/bittensor/tensor.py index c95c43c6d6..3b4c090845 100644 --- a/bittensor/tensor.py +++ b/bittensor/tensor.py @@ -75,10 +75,14 @@ def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str: return TORCH_DTYPES[raw] elif isinstance(raw, str): if os.getenv("USE_TORCH"): - assert raw in TORCH_DTYPES, f"{raw} not a valid torch type in dict {TORCH_DTYPES}" + assert ( + raw in TORCH_DTYPES + ), f"{raw} not a valid torch type in dict {TORCH_DTYPES}" return raw else: - assert raw in NUMPY_DTYPES, f"{raw} not a valid numpy type in dict {NUMPY_DTYPES}" + assert ( + raw in NUMPY_DTYPES + ), f"{raw} not a valid numpy type in dict {NUMPY_DTYPES}" return raw else: raise Exception( @@ -120,7 +124,9 @@ def cast_shape(raw: Union[None, List[int], str]) -> str: class tensor: def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]): if isinstance(tensor, list) or isinstance(tensor, np.ndarray): - tensor = torch.tensor(tensor) if os.getenv("USE_TORCH") else np.array(tensor) + tensor = ( + torch.tensor(tensor) if os.getenv("USE_TORCH") else np.array(tensor) + ) return Tensor.serialize(tensor=tensor) @@ -144,7 +150,11 @@ def tolist(self) -> List[object]: return self.deserialize().tolist() def numpy(self) -> "numpy.ndarray": - return self.deserialize().detach().numpy() if os.getenv("USE_TORCH") else self.deserialize() + return ( + self.deserialize().detach().numpy() + if os.getenv("USE_TORCH") + else self.deserialize() + ) def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]: """ @@ -213,7 +223,7 @@ def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor": dtype: str = pydantic.Field( title="dtype", description="Tensor data type. " - "This field specifies the data type of the tensor, such as numpy.float32 or torch.int64.", + "This field specifies the data type of the tensor, such as numpy.float32 or torch.int64.", examples="np.float32", allow_mutation=False, repr=True, diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index aac4c7f286..8ae13dd8dd 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -34,7 +34,6 @@ def _error(): raise ImportError def __bool__(self): - self._error() return False def __getattr__(self, *_): diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 12c2b903f1..298ca33f75 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -1,4 +1,4 @@ -""" Conversion for weight between chain representation and np.array +""" Conversion for weight between chain representation and np.array or torch.Tensor """ # The MIT License (MIT) @@ -18,18 +18,22 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import os import numpy as np import bittensor from numpy.typing import NDArray -from typing import Tuple, List +from typing import Tuple, List, Union +from bittensor.utils import torch U32_MAX = 4294967295 U16_MAX = 65535 +USE_TORCH = True if os.getenv("USE_TORCH") == 1 else False + def normalize_max_weight( - x: NDArray[np.float32], limit: float = 0.1 -) -> NDArray[np.float32]: + x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1 +) -> Union[NDArray[np.float32], "torch.FloatTensor"]: r"""Normalizes the tensor x so that sum(x) = 1 and the max value is not greater than the limit. Args: x (:obj:`np.float32`): @@ -42,8 +46,14 @@ def normalize_max_weight( """ epsilon = 1e-7 # For numerical stability after normalization - weights = x.copy() - values = np.sort(weights) + weights = x.clone() if USE_TORCH else x.copy() + if USE_TORCH: + values, _ = torch.sort(weights) + else: + values = np.sort(weights) + + if USE_TORCH and x.sum() == 0 or len(x) * limit <= 1: + return torch.ones_like(x) / x.size(0) if x.sum() == 0 or x.shape[0] * limit <= 1: return np.ones_like(x) / x.shape[0] @@ -54,11 +64,16 @@ def normalize_max_weight( return weights / weights.sum() # Find the cumlative sum and sorted tensor - cumsum = np.cumsum(estimation, 0) + cumsum = torch.cumsum(estimation, 0) if USE_TORCH else np.cumsum(estimation, 0) # Determine the index of cutoff - estimation_sum = np.array( - [(len(values) - i - 1) * estimation[i] for i in range(len(values))] + estimation_sum_data = [ + (len(values) - i - 1) * estimation[i] for i in range(len(values)) + ] + estimation_sum = ( + torch.tensor(estimation_sum_data) + if USE_TORCH + else np.array(estimation_sum_data) ) n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum() @@ -78,7 +93,7 @@ def normalize_max_weight( def convert_weight_uids_and_vals_to_tensor( n: int, uids: List[int], weights: List[int] -) -> NDArray[np.float32]: +) -> Union[NDArray[np.float32], "torch.FloatTensor"]: r"""Converts weights and uids from chain representation into a np.array (inverse operation from convert_weights_and_uids_for_emit) Args: n: int: @@ -88,10 +103,14 @@ def convert_weight_uids_and_vals_to_tensor( weights (:obj:`List[int],`): Tensor of weights. Returns: - row_weights ( np.float32 ): + row_weights ( np.float32 or torch.FloatTensor ): Converted row weights. """ - row_weights = np.zeros([n], dtype=np.float32) + row_weights = ( + torch.zeros([n], dtype=torch.float32) + if USE_TORCH + else np.zeros([n], dtype=np.float32) + ) for uid_j, wij in list(zip(uids, weights)): row_weights[uid_j] = float( wij @@ -104,8 +123,8 @@ def convert_weight_uids_and_vals_to_tensor( def convert_root_weight_uids_and_vals_to_tensor( n: int, uids: List[int], weights: List[int], subnets: List[int] -) -> NDArray[np.float32]: - r"""Converts root weights and uids from chain representation into a np.array (inverse operation from convert_weights_and_uids_for_emit) +) -> Union[NDArray[np.float32], "torch.FloatTensor"]: + r"""Converts root weights and uids from chain representation into a np.array or torch FloatTensor (inverse operation from convert_weights_and_uids_for_emit) Args: n: int: number of neurons on network. @@ -120,7 +139,11 @@ def convert_root_weight_uids_and_vals_to_tensor( Converted row weights. """ - row_weights = np.zeros([n], dtype=np.float32) + row_weights = ( + torch.zeros([n], dtype=torch.float32) + if USE_TORCH + else np.zeros([n], dtype=np.float32) + ) for uid_j, wij in list(zip(uids, weights)): if uid_j in subnets: index_s = subnets.index(uid_j) @@ -137,7 +160,7 @@ def convert_root_weight_uids_and_vals_to_tensor( def convert_bond_uids_and_vals_to_tensor( n: int, uids: List[int], bonds: List[int] -) -> NDArray[np.int64]: +) -> Union[NDArray[np.int64], "torch.LongTensor"]: r"""Converts bond and uids from chain representation into a np.array. Args: n: int: @@ -150,14 +173,19 @@ def convert_bond_uids_and_vals_to_tensor( row_bonds ( np.float32 ): Converted row bonds. """ - row_bonds = np.zeros([n], dtype=np.int64) + row_bonds = ( + torch.zeros([n], dtype=torch.int64) + if USE_TORCH + else np.zeros([n], dtype=np.int64) + ) for uid_j, bij in list(zip(uids, bonds)): row_bonds[uid_j] = int(bij) return row_bonds def convert_weights_and_uids_for_emit( - uids: NDArray[np.int64], weights: NDArray[np.float32] + uids: Union[NDArray[np.int64], "torch.LongTensor"], + weights: Union[NDArray[np.float32], "torch.FloatTensor"], ) -> Tuple[List[int], List[int]]: r"""Converts weights into integer u32 representation that sum to MAX_INT_WEIGHT. Args: @@ -210,13 +238,16 @@ def convert_weights_and_uids_for_emit( def process_weights_for_netuid( - uids: NDArray[np.int64], - weights: NDArray[np.float32], + uids: Union[NDArray[np.int64], "torch.Tensor"], + weights: Union[NDArray[np.float32], "torch.Tensor"], netuid: int, subtensor: "bittensor.subtensor", metagraph: "bittensor.metagraph" = None, exclude_quantile: int = 0, -) -> Tuple[NDArray[np.int64], NDArray[np.float32]]: +) -> Union[ + Tuple["torch.Tensor", "torch.FloatTensor"], + Tuple[NDArray[np.int64], NDArray[np.float32]], +]: bittensor.logging.debug("process_weights_for_netuid()") bittensor.logging.debug("weights", weights) bittensor.logging.debug("netuid", netuid) @@ -228,8 +259,12 @@ def process_weights_for_netuid( metagraph = subtensor.metagraph(netuid) # Cast weights to floats. - if not isinstance(weights, np.float32): - weights = weights.astype(np.float32) + if not USE_TORCH: + if not isinstance(weights, torch.FloatTensor): + weights = weights.type(torch.float32) + else: + if not isinstance(weights, np.float32): + weights = weights.astype(np.float32) # Network configuration parameters from an subtensor. # These parameters determine the range of acceptable weights for each neuron. @@ -241,29 +276,54 @@ def process_weights_for_netuid( bittensor.logging.debug("max_weight_limit", max_weight_limit) # Find all non zero weights. - non_zero_weight_idx = np.argwhere(weights > 0).squeeze(axis=1) + non_zero_weight_idx = ( + torch.argwhere(weights > 0).squeeze(dim=1) + if USE_TORCH + else np.argwhere(weights > 0).squeeze(axis=1) + ) non_zero_weight_uids = uids[non_zero_weight_idx] non_zero_weights = weights[non_zero_weight_idx] - if non_zero_weights.size == 0 or metagraph.n < min_allowed_weights: + nzw_size = non_zero_weights.numel() if USE_TORCH else non_zero_weights.size + if nzw_size == 0 or metagraph.n < min_allowed_weights: bittensor.logging.warning("No non-zero weights returning all ones.") - final_weights = np.ones((metagraph.n), dtype=np.int64) / metagraph.n + final_weights = ( + torch.ones((metagraph.n)).to(metagraph.n) / metagraph.n + if USE_TORCH + else np.ones((metagraph.n), dtype=np.int64) / metagraph.n + ) bittensor.logging.debug("final_weights", final_weights) - return np.arange(len(final_weights)), final_weights + final_weights_count = ( + torch.tensor(list(range(len(final_weights)))) + if USE_TORCH + else np.arange(len(final_weights)) + ) + return ( + (final_weights_count, final_weights) + if USE_TORCH + else (final_weights_count, final_weights) + ) - elif non_zero_weights.size < min_allowed_weights: + elif nzw_size < min_allowed_weights: bittensor.logging.warning( "No non-zero weights less then min allowed weight, returning all ones." ) # ( const ): Should this be np.zeros( ( metagraph.n ) ) to reset everyone to build up weight? weights = ( - np.ones((metagraph.n), dtype=np.int64) * 1e-5 + torch.ones((metagraph.n)).to(metagraph.n) * 1e-5 + if USE_TORCH + else np.ones((metagraph.n), dtype=np.int64) * 1e-5 ) # creating minimum even non-zero weights weights[non_zero_weight_idx] += non_zero_weights bittensor.logging.debug("final_weights", weights) normalized_weights = bittensor.utils.weight_utils.normalize_max_weight( x=weights, limit=max_weight_limit ) - return np.arange(len(normalized_weights)), normalized_weights + nw_arange = ( + torch.tensor(list(range(len(normalized_weights)))) + if USE_TORCH + else np.arange(len(normalized_weights)) + ) + return nw_arange, normalized_weights bittensor.logging.debug("non_zero_weights", non_zero_weights) @@ -272,7 +332,11 @@ def process_weights_for_netuid( non_zero_weights ) exclude_quantile = min([quantile, max_exclude]) - lowest_quantile = np.quantile(non_zero_weights, exclude_quantile) + lowest_quantile = ( + non_zero_weights.quantile(exclude_quantile) + if USE_TORCH + else np.quantile(non_zero_weights, exclude_quantile) + ) bittensor.logging.debug("max_exclude", max_exclude) bittensor.logging.debug("exclude_quantile", exclude_quantile) bittensor.logging.debug("lowest_quantile", lowest_quantile) From 0abdc8837bd856d291ea06cae728a20d1de2b394 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 21:16:18 +0200 Subject: [PATCH 06/18] [WIP] Check-in --- bittensor/dendrite.py | 2 +- bittensor/utils/weight_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 9d192d9fd9..1c90ff489d 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -106,7 +106,7 @@ def __init__( The user's wallet or keypair used for signing messages. Defaults to ``None``, in which case a new :func:`bittensor.wallet().hotkey` is generated and used. """ # Initialize the parent class - super(DendriteMixin, self).__init__() + # super(DendriteMixin, self).__init__() # Unique identifier for the instance self.uuid = str(uuid.uuid1()) diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 298ca33f75..06c16ceac8 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -28,7 +28,7 @@ U32_MAX = 4294967295 U16_MAX = 65535 -USE_TORCH = True if os.getenv("USE_TORCH") == 1 else False +USE_TORCH = True if os.getenv("USE_TORCH") == "1" else False def normalize_max_weight( From a8a48ff5b3a75a07b7f72b53fdebbf0629507555 Mon Sep 17 00:00:00 2001 From: bhimes Date: Mon, 20 May 2024 22:19:47 +0200 Subject: [PATCH 07/18] [WIP] Fixed dendrite Mixin --- bittensor/dendrite.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 1c90ff489d..84e43e24ea 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -28,6 +28,8 @@ from typing import Union, Optional, List, Union, AsyncGenerator, Any from .utils import torch +USE_TORCH = True if os.getenv("USE_TORCH") == "1" else False + class DendriteMixin: """ @@ -106,7 +108,7 @@ def __init__( The user's wallet or keypair used for signing messages. Defaults to ``None``, in which case a new :func:`bittensor.wallet().hotkey` is generated and used. """ # Initialize the parent class - # super(DendriteMixin, self).__init__() + super(DendriteMixin, self).__init__() # Unique identifier for the instance self.uuid = str(uuid.uuid1()) @@ -809,18 +811,14 @@ def __del__(self): self.close_session() -if os.environ.get("USE_TORCH"): - - class dendrite(torch.nn.Module, DendriteMixin): - def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): - torch.nn.Module.__init__(self) - DendriteMixin.__init__(self, wallet) +BaseClass = torch.nn.Module if USE_TORCH else object -else: - class dendrite(DendriteMixin): - def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): - DendriteMixin.__init__(self, wallet) +class dendrite(DendriteMixin, BaseClass): + def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): + if USE_TORCH: + torch.nn.Module.__init__(self) + DendriteMixin.__init__(self, wallet) - async def __call__(self, *args, **kwargs): - return await self.forward(*args, **kwargs) + async def __call__(self, *args, **kwargs): + return await self.forward(*args, **kwargs) From bdb8ec646795ad9237b750f678daddb6d978cea2 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 00:06:31 +0200 Subject: [PATCH 08/18] [WIP] Tests --- bittensor/dendrite.py | 8 ++-- bittensor/extrinsics/registration.py | 6 +-- bittensor/utils/registration.py | 30 +++++++++++++-- bittensor/utils/weight_utils.py | 38 ++++++++++--------- tests/integration_tests/test_cli.py | 5 +++ .../test_metagraph_integration.py | 2 + .../test_subtensor_integration.py | 1 - .../extrinsics/test_registration.py | 5 +++ tests/unit_tests/extrinsics/test_root.py | 8 ++++ 9 files changed, 74 insertions(+), 29 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 84e43e24ea..6185374481 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -28,7 +28,9 @@ from typing import Union, Optional, List, Union, AsyncGenerator, Any from .utils import torch -USE_TORCH = True if os.getenv("USE_TORCH") == "1" else False + +def use_torch() -> bool: + return True if os.getenv("USE_TORCH") == "1" else False class DendriteMixin: @@ -811,12 +813,12 @@ def __del__(self): self.close_session() -BaseClass = torch.nn.Module if USE_TORCH else object +BaseClass = torch.nn.Module if use_torch() else object class dendrite(DendriteMixin, BaseClass): def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): - if USE_TORCH: + if use_torch(): torch.nn.Module.__init__(self) DendriteMixin.__init__(self, wallet) diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index 50dc33a9f5..5504d8889f 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -17,7 +17,7 @@ # DEALINGS IN THE SOFTWARE. import bittensor - +import os import time from rich.prompt import Confirm from typing import List, Union, Optional, Tuple @@ -101,7 +101,7 @@ def register_extrinsic( ): return False - if not torch: + if not os.getenv("USE_TORCH"): return False # Attempt rolling registration. @@ -381,7 +381,7 @@ def run_faucet_extrinsic( ): return False - if not torch: + if not os.getenv("USE_TORCH"): return False # Unlock coldkey diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index 8ae13dd8dd..bb8c1d99a6 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -26,6 +26,10 @@ class Torch: + + def __init__(self): + self._transformed = False + @staticmethod def _error(): bittensor.logging.warning( @@ -33,14 +37,32 @@ def _error(): ) raise ImportError + def _transform(self): + try: + import torch as real_torch + self.__dict__.update(real_torch.__dict__) + self._transformed = True + except ImportError: + self._error() + def __bool__(self): return False - def __getattr__(self, *_): - self._error() + def __getattr__(self, name): + if not self._transformed and os.getenv("USE_TORCH"): + self._transform() + if self._transformed: + return getattr(self, name) + else: + self._error() - def __call__(self, *_): - self._error() + def __call__(self, *args, **kwargs): + if not self._transformed and os.getenv("USE_TORCH"): + self._transform() + if self._transformed: + return self(*args, **kwargs) + else: + self._error() if not torch or not os.getenv("USE_TORCH"): diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 06c16ceac8..01be6b98ee 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -28,7 +28,9 @@ U32_MAX = 4294967295 U16_MAX = 65535 -USE_TORCH = True if os.getenv("USE_TORCH") == "1" else False + +def use_torch() -> bool: + return True if os.getenv("USE_TORCH") == "1" else False def normalize_max_weight( @@ -46,13 +48,13 @@ def normalize_max_weight( """ epsilon = 1e-7 # For numerical stability after normalization - weights = x.clone() if USE_TORCH else x.copy() - if USE_TORCH: + weights = x.clone() if use_torch() else x.copy() + if use_torch(): values, _ = torch.sort(weights) else: values = np.sort(weights) - if USE_TORCH and x.sum() == 0 or len(x) * limit <= 1: + if use_torch() and x.sum() == 0 or len(x) * limit <= 1: return torch.ones_like(x) / x.size(0) if x.sum() == 0 or x.shape[0] * limit <= 1: @@ -64,7 +66,7 @@ def normalize_max_weight( return weights / weights.sum() # Find the cumlative sum and sorted tensor - cumsum = torch.cumsum(estimation, 0) if USE_TORCH else np.cumsum(estimation, 0) + cumsum = torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0) # Determine the index of cutoff estimation_sum_data = [ @@ -72,7 +74,7 @@ def normalize_max_weight( ] estimation_sum = ( torch.tensor(estimation_sum_data) - if USE_TORCH + if use_torch() else np.array(estimation_sum_data) ) n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum() @@ -108,7 +110,7 @@ def convert_weight_uids_and_vals_to_tensor( """ row_weights = ( torch.zeros([n], dtype=torch.float32) - if USE_TORCH + if use_torch() else np.zeros([n], dtype=np.float32) ) for uid_j, wij in list(zip(uids, weights)): @@ -141,7 +143,7 @@ def convert_root_weight_uids_and_vals_to_tensor( row_weights = ( torch.zeros([n], dtype=torch.float32) - if USE_TORCH + if use_torch() else np.zeros([n], dtype=np.float32) ) for uid_j, wij in list(zip(uids, weights)): @@ -175,7 +177,7 @@ def convert_bond_uids_and_vals_to_tensor( """ row_bonds = ( torch.zeros([n], dtype=torch.int64) - if USE_TORCH + if use_torch() else np.zeros([n], dtype=np.int64) ) for uid_j, bij in list(zip(uids, bonds)): @@ -259,7 +261,7 @@ def process_weights_for_netuid( metagraph = subtensor.metagraph(netuid) # Cast weights to floats. - if not USE_TORCH: + if not use_torch(): if not isinstance(weights, torch.FloatTensor): weights = weights.type(torch.float32) else: @@ -278,28 +280,28 @@ def process_weights_for_netuid( # Find all non zero weights. non_zero_weight_idx = ( torch.argwhere(weights > 0).squeeze(dim=1) - if USE_TORCH + if use_torch() else np.argwhere(weights > 0).squeeze(axis=1) ) non_zero_weight_uids = uids[non_zero_weight_idx] non_zero_weights = weights[non_zero_weight_idx] - nzw_size = non_zero_weights.numel() if USE_TORCH else non_zero_weights.size + nzw_size = non_zero_weights.numel() if use_torch() else non_zero_weights.size if nzw_size == 0 or metagraph.n < min_allowed_weights: bittensor.logging.warning("No non-zero weights returning all ones.") final_weights = ( torch.ones((metagraph.n)).to(metagraph.n) / metagraph.n - if USE_TORCH + if use_torch() else np.ones((metagraph.n), dtype=np.int64) / metagraph.n ) bittensor.logging.debug("final_weights", final_weights) final_weights_count = ( torch.tensor(list(range(len(final_weights)))) - if USE_TORCH + if use_torch() else np.arange(len(final_weights)) ) return ( (final_weights_count, final_weights) - if USE_TORCH + if use_torch() else (final_weights_count, final_weights) ) @@ -310,7 +312,7 @@ def process_weights_for_netuid( # ( const ): Should this be np.zeros( ( metagraph.n ) ) to reset everyone to build up weight? weights = ( torch.ones((metagraph.n)).to(metagraph.n) * 1e-5 - if USE_TORCH + if use_torch() else np.ones((metagraph.n), dtype=np.int64) * 1e-5 ) # creating minimum even non-zero weights weights[non_zero_weight_idx] += non_zero_weights @@ -320,7 +322,7 @@ def process_weights_for_netuid( ) nw_arange = ( torch.tensor(list(range(len(normalized_weights)))) - if USE_TORCH + if use_torch() else np.arange(len(normalized_weights)) ) return nw_arange, normalized_weights @@ -334,7 +336,7 @@ def process_weights_for_netuid( exclude_quantile = min([quantile, max_exclude]) lowest_quantile = ( non_zero_weights.quantile(exclude_quantile) - if USE_TORCH + if use_torch() else np.quantile(non_zero_weights, exclude_quantile) ) bittensor.logging.debug("max_exclude", max_exclude) diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index 66e4726c1a..acedbba472 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -2088,6 +2088,10 @@ def test_register(self, _): self.assertTrue(registered) def test_pow_register(self, _): + # Not the best way to do this, but I need to finish these tests, and unittest doesn't make this + # as simple as pytest + import os + os.environ["USE_TORCH"] = "1" config = self.config config.command = "subnets" config.subcommand = "pow_register" @@ -2111,6 +2115,7 @@ class MockException(Exception): mock_create_wallet.assert_called_once() self.assertEqual(mock_is_stale.call_count, 1) + os.unsetenv("USE_TORCH") def test_stake(self, _): amount_to_stake: Balance = Balance.from_tao(0.5) diff --git a/tests/integration_tests/test_metagraph_integration.py b/tests/integration_tests/test_metagraph_integration.py index 5dbb9ddfc1..82a0bc8878 100644 --- a/tests/integration_tests/test_metagraph_integration.py +++ b/tests/integration_tests/test_metagraph_integration.py @@ -58,6 +58,7 @@ def test_load_sync_save(self): self.metagraph.save() def test_load_sync_save_from_torch(self): + os.environ["USE_TORCH"] = "1" self.metagraph.sync(lite=True, subtensor=self.sub) def deprecated_save_torch(metagraph): @@ -73,6 +74,7 @@ def deprecated_save_torch(metagraph): deprecated_save_torch(self.metagraph) self.metagraph.load() + os.unsetenv("USE_TORCH") def test_state_dict(self): self.metagraph.load() diff --git a/tests/integration_tests/test_subtensor_integration.py b/tests/integration_tests/test_subtensor_integration.py index b4b6e905e5..a54fe93468 100644 --- a/tests/integration_tests/test_subtensor_integration.py +++ b/tests/integration_tests/test_subtensor_integration.py @@ -16,7 +16,6 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. - import random import socket import unittest diff --git a/tests/unit_tests/extrinsics/test_registration.py b/tests/unit_tests/extrinsics/test_registration.py index f20d037b96..9c065207b8 100644 --- a/tests/unit_tests/extrinsics/test_registration.py +++ b/tests/unit_tests/extrinsics/test_registration.py @@ -50,6 +50,11 @@ def mock_new_wallet(): return mock +@pytest.fixture(autouse=True) +def set_use_torch_env(monkeypatch): + monkeypatch.setenv("USE_TORCH", "1") + + @pytest.mark.parametrize( "wait_for_inclusion,wait_for_finalization,prompt,cuda,dev_id,tpb,num_processes,update_interval,log_verbose,expected", [ diff --git a/tests/unit_tests/extrinsics/test_root.py b/tests/unit_tests/extrinsics/test_root.py index 4806a022a8..08b9e3fa44 100644 --- a/tests/unit_tests/extrinsics/test_root.py +++ b/tests/unit_tests/extrinsics/test_root.py @@ -1,3 +1,4 @@ +import os import pytest from unittest.mock import MagicMock, patch from bittensor.subtensor import subtensor as Subtensor @@ -21,6 +22,11 @@ def mock_wallet(): return mock +@pytest.fixture(autouse=True) +def set_use_torch_env(monkeypatch): + monkeypatch.setenv("USE_TORCH", "1") + + @pytest.mark.parametrize( "wait_for_inclusion, wait_for_finalization, hotkey_registered, registration_success, prompt, user_response, expected_result", [ @@ -70,6 +76,8 @@ def mock_wallet(): "failure-prompt-declined", ], ) + + def test_root_register_extrinsic( mock_subtensor, mock_wallet, From e54301fad8dbe318f9376fd66ba5e7fac605fd41 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 11:59:39 +0200 Subject: [PATCH 09/18] [WIP] Fixed tests. --- bittensor/dendrite.py | 24 ++++++++++++------- bittensor/utils/registration.py | 2 +- bittensor/utils/weight_utils.py | 4 +++- tests/integration_tests/test_cli.py | 3 +-- .../test_metagraph_integration.py | 2 +- .../test_subtensor_integration.py | 11 ++++++++- tests/unit_tests/extrinsics/test_root.py | 2 -- 7 files changed, 32 insertions(+), 16 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 6185374481..0dd34385e5 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -813,14 +813,22 @@ def __del__(self): self.close_session() -BaseClass = torch.nn.Module if use_torch() else object +if use_torch(): - -class dendrite(DendriteMixin, BaseClass): - def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): - if use_torch(): + class dendrite(DendriteMixin, torch.nn.Module): + def __init__( + self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None + ): torch.nn.Module.__init__(self) - DendriteMixin.__init__(self, wallet) + DendriteMixin.__init__(self, wallet) + +else: + + class dendrite(DendriteMixin): + def __init__( + self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None + ): + DendriteMixin.__init__(self, wallet) - async def __call__(self, *args, **kwargs): - return await self.forward(*args, **kwargs) + async def __call__(self, *args, **kwargs): + return await self.forward(*args, **kwargs) diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index bb8c1d99a6..69f38f44db 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -26,7 +26,6 @@ class Torch: - def __init__(self): self._transformed = False @@ -40,6 +39,7 @@ def _error(): def _transform(self): try: import torch as real_torch + self.__dict__.update(real_torch.__dict__) self._transformed = True except ImportError: diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 01be6b98ee..3efe2310dc 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -66,7 +66,9 @@ def normalize_max_weight( return weights / weights.sum() # Find the cumlative sum and sorted tensor - cumsum = torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0) + cumsum = ( + torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0) + ) # Determine the index of cutoff estimation_sum_data = [ diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index acedbba472..a449604a80 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -2090,7 +2090,6 @@ def test_register(self, _): def test_pow_register(self, _): # Not the best way to do this, but I need to finish these tests, and unittest doesn't make this # as simple as pytest - import os os.environ["USE_TORCH"] = "1" config = self.config config.command = "subnets" @@ -2115,7 +2114,7 @@ class MockException(Exception): mock_create_wallet.assert_called_once() self.assertEqual(mock_is_stale.call_count, 1) - os.unsetenv("USE_TORCH") + del os.environ["USE_TORCH"] def test_stake(self, _): amount_to_stake: Balance = Balance.from_tao(0.5) diff --git a/tests/integration_tests/test_metagraph_integration.py b/tests/integration_tests/test_metagraph_integration.py index 82a0bc8878..b2231e3e92 100644 --- a/tests/integration_tests/test_metagraph_integration.py +++ b/tests/integration_tests/test_metagraph_integration.py @@ -74,7 +74,7 @@ def deprecated_save_torch(metagraph): deprecated_save_torch(self.metagraph) self.metagraph.load() - os.unsetenv("USE_TORCH") + del os.environ["USE_TORCH"] def test_state_dict(self): self.metagraph.load() diff --git a/tests/integration_tests/test_subtensor_integration.py b/tests/integration_tests/test_subtensor_integration.py index a54fe93468..5c2ff0cf34 100644 --- a/tests/integration_tests/test_subtensor_integration.py +++ b/tests/integration_tests/test_subtensor_integration.py @@ -18,6 +18,7 @@ import random import socket +import os import unittest from queue import Empty as QueueEmpty from unittest.mock import MagicMock, patch @@ -422,6 +423,7 @@ def test_is_hotkey_registered_not_registered(self): self.assertFalse(registered, msg="Hotkey should not be registered") def test_registration_multiprocessed_already_registered(self): + os.environ["USE_TORCH"] = "1" workblocks_before_is_registered = random.randint(5, 10) # return False each work block but return True after a random number of blocks is_registered_return_values = ( @@ -475,8 +477,10 @@ def test_registration_multiprocessed_already_registered(self): self.subtensor.is_hotkey_registered.call_count == workblocks_before_is_registered + 2 ) + del os.environ["USE_TORCH"] def test_registration_partly_failed(self): + os.environ["USE_TORCH"] = "1" do_pow_register_mock = MagicMock( side_effect=[(False, "Failed"), (False, "Failed"), (True, None)] ) @@ -510,8 +514,10 @@ def is_registered_side_effect(*args, **kwargs): ), msg="Registration should succeed", ) + del os.environ["USE_TORCH"] def test_registration_failed(self): + os.environ["USE_TORCH"] = "1" is_registered_return_values = [False for _ in range(100)] current_block = [i for i in range(0, 100)] mock_neuron = MagicMock() @@ -545,9 +551,11 @@ def test_registration_failed(self): msg="Registration should fail", ) self.assertEqual(mock_create_pow.call_count, 3) + del os.environ["USE_TORCH"] def test_registration_stale_then_continue(self): - # verifty that after a stale solution, the solve will continue without exiting + # verify that after a stale solution, the solve will continue without exiting + os.environ["USE_TORCH"] = "1" class ExitEarly(Exception): pass @@ -588,6 +596,7 @@ class ExitEarly(Exception): 1, msg="only tries to submit once, then exits", ) + del os.environ["USE_TORCH"] def test_defaults_to_finney(self): sub = bittensor.subtensor() diff --git a/tests/unit_tests/extrinsics/test_root.py b/tests/unit_tests/extrinsics/test_root.py index 08b9e3fa44..a8de9e21c1 100644 --- a/tests/unit_tests/extrinsics/test_root.py +++ b/tests/unit_tests/extrinsics/test_root.py @@ -76,8 +76,6 @@ def set_use_torch_env(monkeypatch): "failure-prompt-declined", ], ) - - def test_root_register_extrinsic( mock_subtensor, mock_wallet, From e12fc6b8b58f02c2ab6ff63f59d510726657a341 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 15:00:31 +0200 Subject: [PATCH 10/18] [WIP] Dendrite fix for mypy redeclaration. --- bittensor/dendrite.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 0dd34385e5..4c31553c48 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -813,22 +813,19 @@ def __del__(self): self.close_session() -if use_torch(): +# For back-compatibility with torch +BaseModel = torch.nn.Module if use_torch() else object - class dendrite(DendriteMixin, torch.nn.Module): - def __init__( - self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None - ): + +class dendrite(DendriteMixin, BaseModel): + def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): + if use_torch(): torch.nn.Module.__init__(self) - DendriteMixin.__init__(self, wallet) + DendriteMixin.__init__(self, wallet) -else: - class dendrite(DendriteMixin): - def __init__( - self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None - ): - DendriteMixin.__init__(self, wallet) +if not use_torch(): + async def call(self, *args, **kwargs): + return await self.forward(*args, **kwargs) - async def __call__(self, *args, **kwargs): - return await self.forward(*args, **kwargs) + dendrite.__call__ = call From 7fa6746fdee70c40a5d499df13b3979201dc6008 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 17:57:28 +0200 Subject: [PATCH 11/18] [WIP] updates --- bittensor/extrinsics/set_weights.py | 25 +- bittensor/metagraph.py | 495 +++++++++++++----- .../test_metagraph_integration.py | 4 +- 3 files changed, 379 insertions(+), 145 deletions(-) diff --git a/bittensor/extrinsics/set_weights.py b/bittensor/extrinsics/set_weights.py index 61960ceb78..9ec62bf157 100644 --- a/bittensor/extrinsics/set_weights.py +++ b/bittensor/extrinsics/set_weights.py @@ -19,12 +19,14 @@ import bittensor import logging +import os import numpy as np from numpy.typing import NDArray from rich.prompt import Confirm from typing import Union, Tuple import bittensor.utils.weight_utils as weight_utils from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME +from bittensor.utils import torch logger = logging.getLogger(BITTENSOR_LOGGER_NAME) @@ -33,8 +35,8 @@ def set_weights_extrinsic( subtensor: "bittensor.subtensor", wallet: "bittensor.wallet", netuid: int, - uids: Union[NDArray[np.int64], list], - weights: Union[NDArray[np.float32], list], + uids: Union[NDArray[np.int64], "torch.LongTensor", list], + weights: Union[NDArray[np.float32], "torch.FloatTensor", list], version_key: int = 0, wait_for_inclusion: bool = False, wait_for_finalization: bool = False, @@ -49,9 +51,9 @@ def set_weights_extrinsic( Bittensor wallet object. netuid (int): The ``netuid`` of the subnet to set weights for. - uids (Union[NDArray[np.int64], list]): + uids (Union[NDArray[np.int64], torch.LongTensor, list]): The ``uint64`` uids of destination neurons. - weights (Union[NDArray[np.float32], list]): + weights (Union[NDArray[np.float32], torch.FloatTensor, list]): The weights to set. These must be ``float`` s and correspond to the passed ``uid`` s. version_key (int): The version key of the validator. @@ -65,12 +67,17 @@ def set_weights_extrinsic( success (bool): Flag is ``true`` if extrinsic was finalized or uncluded in the block. If we did not wait for finalization / inclusion, the response is ``true``. """ - # First convert types. - if isinstance(uids, list): - uids = np.array(uids, dtype=np.int64) - if isinstance(weights, list): - weights = np.array(weights, dtype=np.float32) + if os.getenv("USE_TORCH"): + if isinstance(uids, list): + uids = torch.tensor(uids, dtype=torch.int64) + if isinstance(weights, list): + weights = torch.tensor(weights, dtype=torch.float32) + else: + if isinstance(uids, list): + uids = np.array(uids, dtype=np.int64) + if isinstance(weights, list): + weights = np.array(weights, dtype=np.float32) # Reformat and normalize. weight_uids, weight_vals = weight_utils.convert_weights_and_uids_for_emit( diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index abb5462348..e7c1c11fff 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -17,6 +17,7 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +from abc import ABC, abstractmethod import os import pickle import numpy as np @@ -24,7 +25,7 @@ import bittensor from os import listdir from os.path import join -from typing import List, Optional +from typing import List, Optional, Union from bittensor.chain_data import AxonInfo from bittensor.utils.registration import torch @@ -49,6 +50,10 @@ ] +def use_torch() -> bool: + return bool(os.getenv("USE_TORCH") == "1") + + def get_save_dir(network: str, netuid: int) -> str: """ Return directory path from ``network`` and ``netuid``. @@ -92,7 +97,7 @@ def latest_block_path(dir_path: str) -> str: return latest_file_full_path -class Metagraph: +class MetagraphMixin(ABC): """ The metagraph class is a core component of the Bittensor network, representing the neural graph that forms the backbone of the decentralized machine learning system. @@ -150,7 +155,7 @@ class Metagraph: """ @property - def S(self) -> NDArray: + def S(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Represents the stake of each neuron in the Bittensor network. Stake is an important concept in the Bittensor ecosystem, signifying the amount of network weight (or “stake”) each neuron holds, @@ -163,7 +168,7 @@ def S(self) -> NDArray: return self.total_stake @property - def R(self) -> NDArray: + def R(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Contains the ranks of neurons in the Bittensor network. Ranks are determined by the network based on each neuron's performance and contributions. Higher ranks typically indicate a greater level of @@ -176,7 +181,7 @@ def R(self) -> NDArray: return self.ranks @property - def I(self) -> NDArray: + def I(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Incentive values of neurons represent the rewards they receive for their contributions to the network. The Bittensor network employs an incentive mechanism that rewards neurons based on their @@ -189,7 +194,7 @@ def I(self) -> NDArray: return self.incentive @property - def E(self) -> NDArray: + def E(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Denotes the emission values of neurons in the Bittensor network. Emissions refer to the distribution or release of rewards (often in the form of cryptocurrency) to neurons, typically based on their stake and @@ -202,7 +207,7 @@ def E(self) -> NDArray: return self.emission @property - def C(self) -> NDArray: + def C(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Represents the consensus values of neurons in the Bittensor network. Consensus is a measure of how much a neuron's contributions are trusted and agreed upon by the majority of the network. It is @@ -217,7 +222,7 @@ def C(self) -> NDArray: return self.consensus @property - def T(self) -> NDArray: + def T(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Represents the trust values assigned to each neuron in the Bittensor network. Trust is a key metric that reflects the reliability and reputation of a neuron based on its past behavior and contributions. It is @@ -233,7 +238,7 @@ def T(self) -> NDArray: return self.trust @property - def Tv(self) -> NDArray: + def Tv(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Contains the validator trust values of neurons in the Bittensor network. Validator trust is specifically associated with neurons that act as validators within the network. This specialized form of trust reflects @@ -249,7 +254,7 @@ def Tv(self) -> NDArray: return self.validator_trust @property - def D(self) -> NDArray: + def D(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Represents the dividends received by neurons in the Bittensor network. Dividends are a form of reward or distribution, typically given to neurons based on their stake, performance, and contribution to the network. @@ -261,7 +266,7 @@ def D(self) -> NDArray: return self.dividends @property - def B(self) -> NDArray: + def B(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Bonds in the Bittensor network represent a speculative reward mechanism where neurons can accumulate bonds in other neurons. Bonds are akin to investments or stakes in other neurons, reflecting a belief in @@ -274,7 +279,7 @@ def B(self) -> NDArray: return self.bonds @property - def W(self) -> NDArray: + def W(self) -> Union[NDArray, "torch.nn.Parameter"]: """ Represents the weights assigned to each neuron in the Bittensor network. In the context of Bittensor, weights are crucial for determining the influence and interaction between neurons. Each neuron is responsible @@ -342,6 +347,23 @@ def addresses(self) -> List[str]: """ return [axon.ip_str() for axon in self.axons] + @abstractmethod + def __init__(self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True): + """ + Initializes a new instance of the metagraph object, setting up the basic structure and parameters based on the provided arguments. + This method is the entry point for creating a metagraph object, + which is a central component in representing the state of the Bittensor network. + Args: + netuid (int): The unique identifier for the network, distinguishing this instance of the metagraph within potentially multiple network configurations. + network (str): The name of the network, which can indicate specific configurations or versions of the Bittensor network. + lite (bool): A flag indicating whether to use a lite version of the metagraph. The lite version may contain less detailed information but can be quicker to initialize and sync. + sync (bool): A flag indicating whether to synchronize the metagraph with the network upon initialization. Synchronization involves updating the metagraph's parameters to reflect the current state of the network. + Example: + Initializing a metagraph object for the Bittensor network with a specific network UID:: + metagraph = metagraph(netuid=123, network="finney", lite=True, sync=True) + """ + pass + def __str__(self) -> str: """ Provides a human-readable string representation of the metagraph object. This representation includes key identifiers and attributes of the metagraph, making it easier to quickly understand @@ -402,52 +424,6 @@ def metadata(self) -> dict: "version": bittensor.__version__, } - def __init__( - self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True - ): - """ - Initializes a new instance of the metagraph object, setting up the basic structure and parameters based on the provided arguments. - - This method is the entry point for creating a metagraph object, - which is a central component in representing the state of the Bittensor network. - - Args: - netuid (int): The unique identifier for the network, distinguishing this instance of the metagraph within potentially multiple network configurations. - network (str): The name of the network, which can indicate specific configurations or versions of the Bittensor network. - lite (bool): A flag indicating whether to use a lite version of the metagraph. The lite version may contain less detailed information but can be quicker to initialize and sync. - sync (bool): A flag indicating whether to synchronize the metagraph with the network upon initialization. Synchronization involves updating the metagraph's parameters to reflect the current state of the network. - - Example: - Initializing a metagraph object for the Bittensor network with a specific network UID:: - - metagraph = metagraph(netuid=123, network="finney", lite=True, sync=True) - """ - super(metagraph, self).__init__() - - self.netuid = netuid - self.network = network - self.version = (np.array([bittensor.__version_as_int__], dtype=np.int64),) - self.n = np.array([0], dtype=np.int64) - self.block = np.array([0], dtype=np.int64) - self.stake = np.array([], dtype=np.float32) - self.total_stake = np.array([], dtype=np.float32) - self.ranks = np.array([], dtype=np.float32) - self.trust = np.array([], dtype=np.float32) - self.consensus = np.array([], dtype=np.float32) - self.validator_trust = np.array([], dtype=np.float32) - self.incentive = np.array([], dtype=np.float32) - self.emission = np.array([], dtype=np.float32) - self.dividends = np.array([], dtype=np.float32) - self.active = np.array([], dtype=np.int64) - self.last_update = np.array([], dtype=np.int64) - self.validator_permit = np.array([], dtype=bool) - self.weights = np.array([], dtype=np.float32) - self.bonds = np.array([], dtype=np.int64) - self.uids = np.array([], dtype=np.int64) - self.axons: List[AxonInfo] = [] - if sync: - self.sync(block=None, lite=lite) - def state_dict(self): return { "netuid": self.netuid, @@ -583,71 +559,8 @@ def _assign_neurons(self, block, lite, subtensor): self.neurons = subtensor.neurons(block=block, netuid=self.netuid) self.lite = lite - def _set_metagraph_attributes(self, block, subtensor): - """ - Sets various attributes of the metagraph based on the latest network data fetched from the subtensor. - - This method updates parameters like the number of neurons, block number, stakes, trusts, ranks, and other neuron-specific information. - - Args: - block: The block number for which the metagraph attributes need to be set. If ``None``, the latest block data is used. - subtensor: The subtensor instance used for fetching the latest network data. - - Internal Usage: - Used internally during the sync process to update the metagraph's attributes:: - - self._set_metagraph_attributes(block, subtensor) - """ - # TODO: Check and test the setting of each attribute - self.n = self._create_tensor(len(self.neurons), dtype=np.int64) - self.version = self._create_tensor( - [bittensor.__version_as_int__], dtype=np.int64 - ) - self.block = self._create_tensor( - block if block else subtensor.block, dtype=np.int64 - ) - self.uids = self._create_tensor( - [neuron.uid for neuron in self.neurons], dtype=np.int64 - ) - self.trust = self._create_tensor( - [neuron.trust for neuron in self.neurons], dtype=np.float32 - ) - self.consensus = self._create_tensor( - [neuron.consensus for neuron in self.neurons], dtype=np.float32 - ) - self.incentive = self._create_tensor( - [neuron.incentive for neuron in self.neurons], dtype=np.float32 - ) - self.dividends = self._create_tensor( - [neuron.dividends for neuron in self.neurons], dtype=np.float32 - ) - self.ranks = self._create_tensor( - [neuron.rank for neuron in self.neurons], dtype=np.float32 - ) - self.emission = self._create_tensor( - [neuron.emission for neuron in self.neurons], dtype=np.float32 - ) - self.active = self._create_tensor( - [neuron.active for neuron in self.neurons], dtype=np.int64 - ) - self.last_update = self._create_tensor( - [neuron.last_update for neuron in self.neurons], dtype=np.int64 - ) - self.validator_permit = self._create_tensor( - [neuron.validator_permit for neuron in self.neurons], dtype=bool - ) - self.validator_trust = self._create_tensor( - [neuron.validator_trust for neuron in self.neurons], dtype=np.float32 - ) - self.total_stake = self._create_tensor( - [neuron.total_stake.tao for neuron in self.neurons], dtype=np.float32 - ) - self.stake = self._create_tensor( - [neuron.stake for neuron in self.neurons], dtype=np.float32 - ) - self.axons = [n.axon_info for n in self.neurons] - - def _create_tensor(self, data, dtype) -> NDArray: + @staticmethod + def _create_tensor(data, dtype) -> Union[NDArray, "torch.nn.Parameter"]: """ Creates a numpy array with the given data and data type. This method is a utility function used internally to encapsulate data into a np.array, making it compatible with the metagraph's numpy model structure. @@ -664,7 +577,7 @@ def _create_tensor(self, data, dtype) -> NDArray: self.stake = self._create_tensor(neuron_stakes, dtype=np.float32) """ # TODO: Check and test the creation of tensor - return np.array(data, dtype=dtype) + return torch.nn.Parameter(torch.tensor(data, dtype=dtype), requires_grad=False) if use_torch() else np.array(data, dtype=dtype) def _set_weights_and_bonds(self, subtensor: Optional[bittensor.subtensor] = None): """ @@ -691,7 +604,7 @@ def _set_weights_and_bonds(self, subtensor: Optional[bittensor.subtensor] = None [neuron.bonds for neuron in self.neurons], "bonds" ) - def _process_weights_or_bonds(self, data, attribute: str) -> NDArray: + def _process_weights_or_bonds(self, data, attribute: str) -> Union[NDArray, "torch.nn.Parameter"]: """ Processes the raw weights or bonds data and converts it into a structured tensor format. This method handles the transformation of neuron connection data (``weights`` or ``bonds``) from a list or other unstructured format into a tensor that can be utilized within the metagraph model. @@ -710,7 +623,10 @@ def _process_weights_or_bonds(self, data, attribute: str) -> NDArray: data_array = [] for item in data: if len(item) == 0: - data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) + if use_torch(): + data_array.append(torch.zeros(len(self.neurons))) + else: + data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) else: uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor @@ -729,6 +645,10 @@ def _process_weights_or_bonds(self, data, attribute: str) -> NDArray: ) ) tensor_param = ( + torch.nn.Parameter(torch.stack(data_array), requires_grad=False) + if len(data_array) + else torch.nn.Parameter() + ) if use_torch() else ( np.stack(data_array) if len(data_array) else np.array([], dtype=np.float32) ) if len(data_array) == 0: @@ -737,9 +657,13 @@ def _process_weights_or_bonds(self, data, attribute: str) -> NDArray: ) return tensor_param + @abstractmethod + def _set_metagraph_attributes(self, block, subtensor): + pass + def _process_root_weights( self, data, attribute: str, subtensor: bittensor.subtensor - ) -> NDArray: + ) -> Union[NDArray, "torch.nn.Parameter"]: """ Specifically processes the root weights data for the metagraph. This method is similar to :func:`_process_weights_or_bonds` but is tailored for processing root weights, which have a different structure and significance in the network. @@ -764,7 +688,10 @@ def _process_root_weights( subnets = subtensor.get_subnets() for item in data: if len(item) == 0: - data_array.append(np.zeros(n_subnets, dtype=np.float32)) + if use_torch(): + data_array.append(torch.zeros(n_subnets)) + else: + data_array.append(np.zeros(n_subnets, dtype=np.float32)) else: uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor @@ -775,6 +702,10 @@ def _process_root_weights( ) tensor_param = ( + torch.nn.Parameter(torch.stack(data_array), requires_grad=False) + if len(data_array) + else torch.nn.Parameter() + ) if use_torch() else ( np.stack(data_array) if len(data_array) else np.array([], dtype=np.float32) ) if len(data_array) == 0: @@ -807,10 +738,17 @@ def save(self) -> "metagraph": """ save_directory = get_save_dir(self.network, self.netuid) os.makedirs(save_directory, exist_ok=True) - graph_filename = save_directory + f"/block-{self.block.item()}.pt" - state_dict = self.state_dict() - with open(graph_filename, "wb") as graph_file: - pickle.dump(state_dict, graph_file) + if use_torch(): + graph_file = save_directory + f"/block-{self.block.item()}.pt" + state_dict = self.state_dict() + state_dict["axons"] = self.axons + torch.save(state_dict, graph_file) + state_dict = torch.load(graph_file) # verifies that the file can be loaded correctly + else: + graph_filename = save_directory + f"/block-{self.block.item()}.pt" + state_dict = self.state_dict() + with open(graph_filename, "wb") as graph_file: + pickle.dump(state_dict, graph_file) return self def load(self): @@ -838,6 +776,7 @@ def load(self): """ self.load_from_path(get_save_dir(self.network, self.netuid)) + @abstractmethod def load_from_path(self, dir_path: str) -> "metagraph": """ Loads the state of the metagraph from a specified directory path. This method is crucial for restoring the metagraph to a specific state based on saved data. It locates the latest block file in the given @@ -865,6 +804,294 @@ def load_from_path(self, dir_path: str) -> "metagraph": contain valid data for the metagraph. It is essential to ensure that the directory path and the state files within it are accurate and consistent with the expected metagraph structure. """ + pass + + +BaseClass = torch.nn.Module if use_torch() else object + + +class TorchMetaGraph(MetagraphMixin, BaseClass): + def __init__( + self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True + ): + """ + Initializes a new instance of the metagraph object, setting up the basic structure and parameters based on the provided arguments. + This method is the entry point for creating a metagraph object, + which is a central component in representing the state of the Bittensor network. + Args: + netuid (int): The unique identifier for the network, distinguishing this instance of the metagraph within potentially multiple network configurations. + network (str): The name of the network, which can indicate specific configurations or versions of the Bittensor network. + lite (bool): A flag indicating whether to use a lite version of the metagraph. The lite version may contain less detailed information but can be quicker to initialize and sync. + sync (bool): A flag indicating whether to synchronize the metagraph with the network upon initialization. Synchronization involves updating the metagraph's parameters to reflect the current state of the network. + Example: + Initializing a metagraph object for the Bittensor network with a specific network UID:: + metagraph = metagraph(netuid=123, network="finney", lite=True, sync=True) + """ + torch.nn.Module.__init__(self) + MetagraphMixin.__init__(self, netuid, network, lite, sync) + self.netuid = netuid + self.network = network + self.version = torch.nn.Parameter( + torch.tensor([bittensor.__version_as_int__], dtype=torch.int64), + requires_grad=False, + ) + self.n: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([0], dtype=torch.int64), requires_grad=False + ) + self.block: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([0], dtype=torch.int64), requires_grad=False + ) + self.stake = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.total_stake: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.ranks: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.trust: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.consensus: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.validator_trust: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.incentive: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.emission: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.dividends: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.active = torch.nn.Parameter( + torch.tensor([], dtype=torch.int64), requires_grad=False + ) + self.last_update = torch.nn.Parameter( + torch.tensor([], dtype=torch.int64), requires_grad=False + ) + self.validator_permit = torch.nn.Parameter( + torch.tensor([], dtype=torch.bool), requires_grad=False + ) + self.weights: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.float32), requires_grad=False + ) + self.bonds: torch.nn.Parameter = torch.nn.Parameter( + torch.tensor([], dtype=torch.int64), requires_grad=False + ) + self.uids = torch.nn.Parameter( + torch.tensor([], dtype=torch.int64), requires_grad=False + ) + self.axons: List[AxonInfo] = [] + if sync: + self.sync(block=None, lite=lite) + + def _set_metagraph_attributes(self, block, subtensor): + """ + Sets various attributes of the metagraph based on the latest network data fetched from the subtensor. + + This method updates parameters like the number of neurons, block number, stakes, trusts, ranks, and other neuron-specific information. + + Args: + block: The block number for which the metagraph attributes need to be set. If ``None``, the latest block data is used. + subtensor: The subtensor instance used for fetching the latest network data. + + Internal Usage: + Used internally during the sync process to update the metagraph's attributes:: + + self._set_metagraph_attributes(block, subtensor) + """ + self.n = self._create_tensor(len(self.neurons), dtype=torch.int64) + self.version = self._create_tensor( + [bittensor.__version_as_int__], dtype=torch.int64 + ) + self.block = self._create_tensor( + block if block else subtensor.block, dtype=torch.int64 + ) + self.uids = self._create_tensor( + [neuron.uid for neuron in self.neurons], dtype=torch.int64 + ) + self.trust = self._create_tensor( + [neuron.trust for neuron in self.neurons], dtype=torch.float32 + ) + self.consensus = self._create_tensor( + [neuron.consensus for neuron in self.neurons], dtype=torch.float32 + ) + self.incentive = self._create_tensor( + [neuron.incentive for neuron in self.neurons], dtype=torch.float32 + ) + self.dividends = self._create_tensor( + [neuron.dividends for neuron in self.neurons], dtype=torch.float32 + ) + self.ranks = self._create_tensor( + [neuron.rank for neuron in self.neurons], dtype=torch.float32 + ) + self.emission = self._create_tensor( + [neuron.emission for neuron in self.neurons], dtype=torch.float32 + ) + self.active = self._create_tensor( + [neuron.active for neuron in self.neurons], dtype=torch.int64 + ) + self.last_update = self._create_tensor( + [neuron.last_update for neuron in self.neurons], dtype=torch.int64 + ) + self.validator_permit = self._create_tensor( + [neuron.validator_permit for neuron in self.neurons], dtype=torch.bool + ) + self.validator_trust = self._create_tensor( + [neuron.validator_trust for neuron in self.neurons], dtype=torch.float32 + ) + self.total_stake = self._create_tensor( + [neuron.total_stake.tao for neuron in self.neurons], dtype=torch.float32 + ) + self.stake = self._create_tensor( + [neuron.stake for neuron in self.neurons], dtype=torch.float32 + ) + self.axons = [n.axon_info for n in self.neurons] + + def load_from_path(self, dir_path: str) -> "metagraph": + graph_file = latest_block_path(dir_path) + state_dict = torch.load(graph_file) + self.n = torch.nn.Parameter(state_dict["n"], requires_grad=False) + self.block = torch.nn.Parameter(state_dict["block"], requires_grad=False) + self.uids = torch.nn.Parameter(state_dict["uids"], requires_grad=False) + self.stake = torch.nn.Parameter(state_dict["stake"], requires_grad=False) + self.total_stake = torch.nn.Parameter( + state_dict["total_stake"], requires_grad=False + ) + self.ranks = torch.nn.Parameter(state_dict["ranks"], requires_grad=False) + self.trust = torch.nn.Parameter(state_dict["trust"], requires_grad=False) + self.consensus = torch.nn.Parameter( + state_dict["consensus"], requires_grad=False + ) + self.validator_trust = torch.nn.Parameter( + state_dict["validator_trust"], requires_grad=False + ) + self.incentive = torch.nn.Parameter( + state_dict["incentive"], requires_grad=False + ) + self.emission = torch.nn.Parameter(state_dict["emission"], requires_grad=False) + self.dividends = torch.nn.Parameter( + state_dict["dividends"], requires_grad=False + ) + self.active = torch.nn.Parameter(state_dict["active"], requires_grad=False) + self.last_update = torch.nn.Parameter( + state_dict["last_update"], requires_grad=False + ) + self.validator_permit = torch.nn.Parameter( + state_dict["validator_permit"], requires_grad=False + ) + self.uids = torch.nn.Parameter(state_dict["uids"], requires_grad=False) + self.axons = state_dict["axons"] + if "weights" in state_dict: + self.weights = torch.nn.Parameter( + state_dict["weights"], requires_grad=False + ) + if "bonds" in state_dict: + self.bonds = torch.nn.Parameter(state_dict["bonds"], requires_grad=False) + return self + + +class NonTorchMetagraph(MetagraphMixin): + def __init__( + self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True + ): + # super(metagraph, self).__init__() + MetagraphMixin.__init__(self, netuid, network, lite, sync) + + self.netuid = netuid + self.network = network + self.version = (np.array([bittensor.__version_as_int__], dtype=np.int64),) + self.n = np.array([0], dtype=np.int64) + self.block = np.array([0], dtype=np.int64) + self.stake = np.array([], dtype=np.float32) + self.total_stake = np.array([], dtype=np.float32) + self.ranks = np.array([], dtype=np.float32) + self.trust = np.array([], dtype=np.float32) + self.consensus = np.array([], dtype=np.float32) + self.validator_trust = np.array([], dtype=np.float32) + self.incentive = np.array([], dtype=np.float32) + self.emission = np.array([], dtype=np.float32) + self.dividends = np.array([], dtype=np.float32) + self.active = np.array([], dtype=np.int64) + self.last_update = np.array([], dtype=np.int64) + self.validator_permit = np.array([], dtype=bool) + self.weights = np.array([], dtype=np.float32) + self.bonds = np.array([], dtype=np.int64) + self.uids = np.array([], dtype=np.int64) + self.axons: List[AxonInfo] = [] + if sync: + self.sync(block=None, lite=lite) + + def _set_metagraph_attributes(self, block, subtensor): + """ + Sets various attributes of the metagraph based on the latest network data fetched from the subtensor. + + This method updates parameters like the number of neurons, block number, stakes, trusts, ranks, and other neuron-specific information. + + Args: + block: The block number for which the metagraph attributes need to be set. If ``None``, the latest block data is used. + subtensor: The subtensor instance used for fetching the latest network data. + + Internal Usage: + Used internally during the sync process to update the metagraph's attributes:: + + self._set_metagraph_attributes(block, subtensor) + """ + # TODO: Check and test the setting of each attribute + self.n = self._create_tensor(len(self.neurons), dtype=np.int64) + self.version = self._create_tensor( + [bittensor.__version_as_int__], dtype=np.int64 + ) + self.block = self._create_tensor( + block if block else subtensor.block, dtype=np.int64 + ) + self.uids = self._create_tensor( + [neuron.uid for neuron in self.neurons], dtype=np.int64 + ) + self.trust = self._create_tensor( + [neuron.trust for neuron in self.neurons], dtype=np.float32 + ) + self.consensus = self._create_tensor( + [neuron.consensus for neuron in self.neurons], dtype=np.float32 + ) + self.incentive = self._create_tensor( + [neuron.incentive for neuron in self.neurons], dtype=np.float32 + ) + self.dividends = self._create_tensor( + [neuron.dividends for neuron in self.neurons], dtype=np.float32 + ) + self.ranks = self._create_tensor( + [neuron.rank for neuron in self.neurons], dtype=np.float32 + ) + self.emission = self._create_tensor( + [neuron.emission for neuron in self.neurons], dtype=np.float32 + ) + self.active = self._create_tensor( + [neuron.active for neuron in self.neurons], dtype=np.int64 + ) + self.last_update = self._create_tensor( + [neuron.last_update for neuron in self.neurons], dtype=np.int64 + ) + self.validator_permit = self._create_tensor( + [neuron.validator_permit for neuron in self.neurons], dtype=bool + ) + self.validator_trust = self._create_tensor( + [neuron.validator_trust for neuron in self.neurons], dtype=np.float32 + ) + self.total_stake = self._create_tensor( + [neuron.total_stake.tao for neuron in self.neurons], dtype=np.float32 + ) + self.stake = self._create_tensor( + [neuron.stake for neuron in self.neurons], dtype=np.float32 + ) + self.axons = [n.axon_info for n in self.neurons] + + def load_from_path(self, dir_path: str) -> "metagraph": graph_filename = latest_block_path(dir_path) try: with open(graph_filename, "rb") as graph_file: @@ -907,5 +1134,5 @@ def load_from_path(self, dir_path: str) -> "metagraph": self.bonds = state_dict["bonds"] return self - -metagraph = Metagraph +print("USE_TORCH", use_torch()) +metagraph = TorchMetaGraph if use_torch() else NonTorchMetagraph diff --git a/tests/integration_tests/test_metagraph_integration.py b/tests/integration_tests/test_metagraph_integration.py index b2231e3e92..c7396a5a9c 100644 --- a/tests/integration_tests/test_metagraph_integration.py +++ b/tests/integration_tests/test_metagraph_integration.py @@ -58,7 +58,7 @@ def test_load_sync_save(self): self.metagraph.save() def test_load_sync_save_from_torch(self): - os.environ["USE_TORCH"] = "1" + # os.environ["USE_TORCH"] = "1" self.metagraph.sync(lite=True, subtensor=self.sub) def deprecated_save_torch(metagraph): @@ -74,7 +74,7 @@ def deprecated_save_torch(metagraph): deprecated_save_torch(self.metagraph) self.metagraph.load() - del os.environ["USE_TORCH"] + # del os.environ["USE_TORCH"] def test_state_dict(self): self.metagraph.load() From e59b78212fceaf849c6ce524a63fe6c23c9baa6d Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 19:28:54 +0200 Subject: [PATCH 12/18] Fixed tests and mypy and black. --- bittensor/dendrite.py | 7 ++- bittensor/metagraph.py | 120 ++++++++++++++++++++++++++++------------- bittensor/subtensor.py | 4 +- 3 files changed, 91 insertions(+), 40 deletions(-) diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 4c31553c48..4354de3ff0 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -817,14 +817,17 @@ def __del__(self): BaseModel = torch.nn.Module if use_torch() else object -class dendrite(DendriteMixin, BaseModel): - def __init__(self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None): +class dendrite(DendriteMixin, BaseModel): # type: ignore + def __init__( + self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None + ): if use_torch(): torch.nn.Module.__init__(self) DendriteMixin.__init__(self, wallet) if not use_torch(): + async def call(self, *args, **kwargs): return await self.forward(*args, **kwargs) diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index e7c1c11fff..5a77a1b57e 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -25,10 +25,10 @@ import bittensor from os import listdir from os.path import join -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple, Any from bittensor.chain_data import AxonInfo -from bittensor.utils.registration import torch +from bittensor.utils.registration import torch, Torch METAGRAPH_STATE_DICT_NDARRAY_KEYS = [ "version", @@ -154,6 +154,28 @@ class MetagraphMixin(ABC): hotkeys = deepcopy(metagraph.hotkeys) """ + netuid: int + network: str + version: Union["torch.nn.Parameter", Tuple[NDArray]] + n: Union["torch.nn.Parameter", NDArray] + block: Union["torch.nn.Parameter", NDArray] + stake: Union["torch.nn.Parameter", NDArray] + total_stake: Union["torch.nn.Parameter", NDArray] + ranks: Union["torch.nn.Parameter", NDArray] + trust: Union["torch.nn.Parameter", NDArray] + consensus: Union["torch.nn.Parameter", NDArray] + validator_trust: Union["torch.nn.Parameter", NDArray] + incentive: Union["torch.nn.Parameter", NDArray] + emission: Union["torch.nn.Parameter", NDArray] + dividends: Union["torch.nn.Parameter", NDArray] + active: Union["torch.nn.Parameter", NDArray] + last_update: Union["torch.nn.Parameter", NDArray] + validator_permit: Union["torch.nn.Parameter", NDArray] + weights: Union["torch.nn.Parameter", NDArray] + bonds: Union["torch.nn.Parameter", NDArray] + uids: Union["torch.nn.Parameter", NDArray] + axons: List[AxonInfo] + @property def S(self) -> Union[NDArray, "torch.nn.Parameter"]: """ @@ -348,7 +370,9 @@ def addresses(self) -> List[str]: return [axon.ip_str() for axon in self.axons] @abstractmethod - def __init__(self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True): + def __init__( + self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True + ): """ Initializes a new instance of the metagraph object, setting up the basic structure and parameters based on the provided arguments. This method is the entry point for creating a metagraph object, @@ -577,7 +601,11 @@ def _create_tensor(data, dtype) -> Union[NDArray, "torch.nn.Parameter"]: self.stake = self._create_tensor(neuron_stakes, dtype=np.float32) """ # TODO: Check and test the creation of tensor - return torch.nn.Parameter(torch.tensor(data, dtype=dtype), requires_grad=False) if use_torch() else np.array(data, dtype=dtype) + return ( + torch.nn.Parameter(torch.tensor(data, dtype=dtype), requires_grad=False) + if use_torch() + else np.array(data, dtype=dtype) + ) def _set_weights_and_bonds(self, subtensor: Optional[bittensor.subtensor] = None): """ @@ -604,7 +632,9 @@ def _set_weights_and_bonds(self, subtensor: Optional[bittensor.subtensor] = None [neuron.bonds for neuron in self.neurons], "bonds" ) - def _process_weights_or_bonds(self, data, attribute: str) -> Union[NDArray, "torch.nn.Parameter"]: + def _process_weights_or_bonds( + self, data, attribute: str + ) -> Union[NDArray, "torch.nn.Parameter"]: """ Processes the raw weights or bonds data and converts it into a structured tensor format. This method handles the transformation of neuron connection data (``weights`` or ``bonds``) from a list or other unstructured format into a tensor that can be utilized within the metagraph model. @@ -624,32 +654,38 @@ def _process_weights_or_bonds(self, data, attribute: str) -> Union[NDArray, "tor for item in data: if len(item) == 0: if use_torch(): - data_array.append(torch.zeros(len(self.neurons))) + data_array.append(torch.zeros(len(self.neurons))) # type: ignore else: - data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) + data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) # type: ignore else: uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor if attribute == "weights": data_array.append( bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor( - len(self.neurons), list(uids), list(values) + len(self.neurons), list(uids), list(values) # type: ignore ) ) else: data_array.append( - bittensor.utils.weight_utils.convert_bond_uids_and_vals_to_tensor( + bittensor.utils.weight_utils.convert_bond_uids_and_vals_to_tensor( # type: ignore len(self.neurons), list(uids), list(values) ).astype( np.float32 ) ) - tensor_param = ( - torch.nn.Parameter(torch.stack(data_array), requires_grad=False) - if len(data_array) - else torch.nn.Parameter() - ) if use_torch() else ( - np.stack(data_array) if len(data_array) else np.array([], dtype=np.float32) + tensor_param: Union["torch.nn.Parameter", NDArray] = ( + ( + torch.nn.Parameter(torch.stack(data_array), requires_grad=False) + if len(data_array) + else torch.nn.Parameter() + ) + if use_torch() + else ( + np.stack(data_array) + if len(data_array) + else np.array([], dtype=np.float32) + ) ) if len(data_array) == 0: bittensor.logging.warning( @@ -691,22 +727,28 @@ def _process_root_weights( if use_torch(): data_array.append(torch.zeros(n_subnets)) else: - data_array.append(np.zeros(n_subnets, dtype=np.float32)) + data_array.append(np.zeros(n_subnets, dtype=np.float32)) # type: ignore else: uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor data_array.append( - bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( + bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore n_subnets, list(uids), list(values), subnets ) ) - tensor_param = ( - torch.nn.Parameter(torch.stack(data_array), requires_grad=False) - if len(data_array) - else torch.nn.Parameter() - ) if use_torch() else ( - np.stack(data_array) if len(data_array) else np.array([], dtype=np.float32) + tensor_param: Union[NDArray, "torch.nn.Parameter"] = ( + ( + torch.nn.Parameter(torch.stack(data_array), requires_grad=False) + if len(data_array) + else torch.nn.Parameter() + ) + if use_torch() + else ( + np.stack(data_array) + if len(data_array) + else np.array([], dtype=np.float32) + ) ) if len(data_array) == 0: bittensor.logging.warning( @@ -714,7 +756,7 @@ def _process_root_weights( ) return tensor_param - def save(self) -> "metagraph": + def save(self) -> "metagraph": # type: ignore """ Saves the current state of the metagraph to a file on disk. This function is crucial for persisting the current state of the network's metagraph, which can later be reloaded or analyzed. The save operation includes all neuron attributes and parameters, ensuring a complete snapshot of the metagraph's state. @@ -739,13 +781,15 @@ def save(self) -> "metagraph": save_directory = get_save_dir(self.network, self.netuid) os.makedirs(save_directory, exist_ok=True) if use_torch(): - graph_file = save_directory + f"/block-{self.block.item()}.pt" + graph_filename = f"{save_directory}/block-{self.block.item()}.pt" state_dict = self.state_dict() state_dict["axons"] = self.axons - torch.save(state_dict, graph_file) - state_dict = torch.load(graph_file) # verifies that the file can be loaded correctly + torch.save(state_dict, graph_filename) + state_dict = torch.load( + graph_filename + ) # verifies that the file can be loaded correctly else: - graph_filename = save_directory + f"/block-{self.block.item()}.pt" + graph_filename = f"{save_directory}/block-{self.block.item()}.pt" state_dict = self.state_dict() with open(graph_filename, "wb") as graph_file: pickle.dump(state_dict, graph_file) @@ -777,7 +821,7 @@ def load(self): self.load_from_path(get_save_dir(self.network, self.netuid)) @abstractmethod - def load_from_path(self, dir_path: str) -> "metagraph": + def load_from_path(self, dir_path: str) -> "metagraph": # type: ignore """ Loads the state of the metagraph from a specified directory path. This method is crucial for restoring the metagraph to a specific state based on saved data. It locates the latest block file in the given directory and loads all metagraph parameters from it. This is particularly useful for analyses that require historical states of the network or for restoring previous states of the metagraph in different @@ -810,9 +854,9 @@ def load_from_path(self, dir_path: str) -> "metagraph": BaseClass = torch.nn.Module if use_torch() else object -class TorchMetaGraph(MetagraphMixin, BaseClass): +class TorchMetaGraph(MetagraphMixin, BaseClass): # type: ignore def __init__( - self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True + self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True ): """ Initializes a new instance of the metagraph object, setting up the basic structure and parameters based on the provided arguments. @@ -953,7 +997,7 @@ def _set_metagraph_attributes(self, block, subtensor): ) self.axons = [n.axon_info for n in self.neurons] - def load_from_path(self, dir_path: str) -> "metagraph": + def load_from_path(self, dir_path: str) -> "metagraph": # type: ignore graph_file = latest_block_path(dir_path) state_dict = torch.load(graph_file) self.n = torch.nn.Parameter(state_dict["n"], requires_grad=False) @@ -998,7 +1042,7 @@ def load_from_path(self, dir_path: str) -> "metagraph": class NonTorchMetagraph(MetagraphMixin): def __init__( - self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True + self, netuid: int, network: str = "finney", lite: bool = True, sync: bool = True ): # super(metagraph, self).__init__() MetagraphMixin.__init__(self, netuid, network, lite, sync) @@ -1091,7 +1135,7 @@ def _set_metagraph_attributes(self, block, subtensor): ) self.axons = [n.axon_info for n in self.neurons] - def load_from_path(self, dir_path: str) -> "metagraph": + def load_from_path(self, dir_path: str) -> "metagraph": # type: ignore graph_filename = latest_block_path(dir_path) try: with open(graph_filename, "rb") as graph_file: @@ -1105,10 +1149,13 @@ def load_from_path(self, dir_path: str) -> "metagraph": "metagraph state from legacy saves, but will not be supported in the future." ) try: - state_dict = torch.load(graph_filename) + import torch as real_torch + + state_dict = real_torch.load(graph_filename) for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS: state_dict[key] = state_dict[key].detach().numpy() - except RuntimeError: + del real_torch + except (RuntimeError, ImportError): bittensor.__console__.print("Unable to load file. It may be corrupted.") raise @@ -1134,5 +1181,6 @@ def load_from_path(self, dir_path: str) -> "metagraph": self.bonds = state_dict["bonds"] return self + print("USE_TORCH", use_torch()) metagraph = TorchMetaGraph if use_torch() else NonTorchMetagraph diff --git a/bittensor/subtensor.py b/bittensor/subtensor.py index 7df6d7223f..b6e63af3e4 100644 --- a/bittensor/subtensor.py +++ b/bittensor/subtensor.py @@ -2303,7 +2303,7 @@ def commit(self, wallet, netuid: int, data: str): def get_commitment(self, netuid: int, uid: int, block: Optional[int] = None) -> str: metagraph = self.metagraph(netuid) - hotkey = metagraph.hotkeys[uid] + hotkey = metagraph.hotkeys[uid] # type: ignore metadata = get_metadata(self, netuid, hotkey, block) commitment = metadata["info"]["fields"][0] # type: ignore @@ -4110,7 +4110,7 @@ def metagraph( netuid: int, lite: bool = True, block: Optional[int] = None, - ) -> "bittensor.metagraph": + ) -> "bittensor.metagraph": # type: ignore """ Returns a synced metagraph for a specified subnet within the Bittensor network. The metagraph represents the network's structure, including neuron connections and interactions. From dbc570117057bc4a2118a5f36c5e5d02e00375a0 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 19:45:30 +0200 Subject: [PATCH 13/18] Flake tests --- bittensor/metagraph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index 5a77a1b57e..940bc6bab2 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -25,10 +25,10 @@ import bittensor from os import listdir from os.path import join -from typing import List, Optional, Union, Tuple, Any +from typing import List, Optional, Union, Tuple from bittensor.chain_data import AxonInfo -from bittensor.utils.registration import torch, Torch +from bittensor.utils.registration import torch METAGRAPH_STATE_DICT_NDARRAY_KEYS = [ "version", From 27f1048be40551a19812b0c83088b7d2a6f51547 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 21:17:07 +0200 Subject: [PATCH 14/18] Roman-requested changes, as well as changing use_torch to a general bool function rather than selecting individually from the `os.environ` every time. --- bittensor/chain_data.py | 19 +++++----- bittensor/commands/root.py | 13 ++++--- bittensor/dendrite.py | 11 ++---- bittensor/extrinsics/registration.py | 6 +-- bittensor/extrinsics/root.py | 11 +++--- bittensor/extrinsics/set_weights.py | 4 +- bittensor/metagraph.py | 9 +---- bittensor/tensor.py | 16 ++++---- bittensor/utils/__init__.py | 37 +++++++++++++++++-- bittensor/utils/registration.py | 10 +++-- bittensor/utils/weight_utils.py | 6 +-- .../test_metagraph_integration.py | 2 - 12 files changed, 82 insertions(+), 62 deletions(-) diff --git a/bittensor/chain_data.py b/bittensor/chain_data.py index 9b81d87803..11fed3ec67 100644 --- a/bittensor/chain_data.py +++ b/bittensor/chain_data.py @@ -25,8 +25,9 @@ from scalecodec.type_registry import load_type_registry_preset from scalecodec.utils.ss58 import ss58_encode -from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT, torch +from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT from .utils.balance import Balance +from .utils.registration import torch, use_torch custom_rpc_type_registry = { "types": { @@ -276,7 +277,7 @@ def to_parameter_dict( self, ) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]: """Returns a torch tensor or dict of the subnet info, depending on the USE_TORCH flag set""" - if os.environ.get("USE_TORCH"): + if use_torch(): return self._to_parameter_dict("torch") else: return self._to_parameter_dict("numpy") @@ -297,7 +298,7 @@ def from_parameter_dict( cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] ) -> "AxonInfo": """Returns an axon_info object from a torch parameter_dict or a parameter dict.""" - if os.environ.get("USE_TORCH"): + if use_torch(): return cls._from_parameter_dict(parameter_dict, "torch") else: return cls._from_parameter_dict(parameter_dict, "numpy") @@ -1019,7 +1020,7 @@ def _to_parameter_dict( def to_parameter_dict(self) -> Union[dict[str, Any], "torch.nn.ParameterDict"]: """Returns a torch tensor or dict of the subnet info.""" - if os.environ.get("USE_TORCH"): + if use_torch(): return self._to_parameter_dict("torch") else: return self._to_parameter_dict("numpy") @@ -1040,7 +1041,7 @@ def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "SubnetIn def from_parameter_dict( cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] ) -> "SubnetInfo": - if os.environ.get("USE_TORCH"): + if use_torch(): return cls._from_parameter_dict_torch(parameter_dict) else: return cls._from_parameter_dict_numpy(parameter_dict) @@ -1142,7 +1143,7 @@ def to_parameter_dict( self, ) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]: """Returns a torch tensor or dict of the subnet hyperparameters.""" - if os.environ.get("USE_TORCH"): + if use_torch(): return self._to_parameter_dict_torch("torch") else: return self._to_parameter_dict_torch("numpy") @@ -1165,7 +1166,7 @@ def _from_parameter_dict_numpy( def from_parameter_dict( cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] ) -> "SubnetHyperparameters": - if os.environ.get("USE_TORCH"): + if use_torch(): return cls._from_parameter_dict_torch(parameter_dict) else: return cls._from_parameter_dict_numpy(parameter_dict) @@ -1237,7 +1238,7 @@ def to_parameter_dict( self, ) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]: """Returns a torch tensor or dict of the subnet IP info.""" - if os.environ.get("USE_TORCH"): + if use_torch(): return self._to_parameter_dict("torch") else: return self._to_parameter_dict("numpy") @@ -1258,7 +1259,7 @@ def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "IPInfo": def from_parameter_dict( cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"] ) -> "IPInfo": - if os.environ.get("USE_TORCH"): + if use_torch(): return cls._from_parameter_dict_torch(parameter_dict) else: return cls._from_parameter_dict_numpy(parameter_dict) diff --git a/bittensor/commands/root.py b/bittensor/commands/root.py index 003d22fe4f..bd819179d5 100644 --- a/bittensor/commands/root.py +++ b/bittensor/commands/root.py @@ -24,7 +24,8 @@ from typing import List, Optional, Dict from rich.prompt import Prompt from rich.table import Table -from .utils import get_delegates_details, DelegatesDetails, torch +from .utils import get_delegates_details, DelegatesDetails +from bittensor.utils.registration import torch, use_torch from . import defaults @@ -303,7 +304,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): my_weights[cli.config.netuid] = new_weight all_netuids = ( torch.tensor(list(range(len(my_weights)))) - if os.environ.get("USE_TORCH") + if use_torch() else np.arange(len(my_weights)) ) @@ -425,7 +426,7 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): my_weights[my_weights < 0] = 0 # Ensure weights don't go negative all_netuids = ( torch.tensor(list(range(len(my_weights)))) - if os.environ.get("USE_TORCH") + if use_torch() else np.arange(len(my_weights)) ) @@ -531,14 +532,14 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids))) netuids = ( torch.tensor(matched_netuids, dtype=torch.long) - if os.environ.get("USE_TORCH") + if use_torch() else np.array(matched_netuids, dtype=np.int64) ) - matched_weights = list(map(float, re.split(r"[ ,]+", cli.config.weights))) + matched_weights = [float(weight) for weight in re.split(r"[ ,]+", cli.config.weights)] weights = ( torch.tensor(matched_weights, dtype=torch.float32) - if os.environ.get("USE_TORCH") + if use_torch() else np.array( matched_weights, dtype=np.float32, diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py index 4354de3ff0..0c28cbff4b 100644 --- a/bittensor/dendrite.py +++ b/bittensor/dendrite.py @@ -20,17 +20,12 @@ from __future__ import annotations import asyncio -import os import uuid import time import aiohttp import bittensor -from typing import Union, Optional, List, Union, AsyncGenerator, Any -from .utils import torch - - -def use_torch() -> bool: - return True if os.getenv("USE_TORCH") == "1" else False +from typing import Optional, List, Union, AsyncGenerator, Any +from bittensor.utils.registration import torch, use_torch class DendriteMixin: @@ -814,7 +809,7 @@ def __del__(self): # For back-compatibility with torch -BaseModel = torch.nn.Module if use_torch() else object +BaseModel: Union["torch.nn.Module", object] = torch.nn.Module if use_torch() else object class dendrite(DendriteMixin, BaseModel): # type: ignore diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index 5504d8889f..fd2fe52675 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -21,7 +21,7 @@ import time from rich.prompt import Confirm from typing import List, Union, Optional, Tuple -from bittensor.utils.registration import POWSolution, create_pow, torch +from bittensor.utils.registration import POWSolution, create_pow, torch, use_torch def register_extrinsic( @@ -101,7 +101,7 @@ def register_extrinsic( ): return False - if not os.getenv("USE_TORCH"): + if not use_torch(): return False # Attempt rolling registration. @@ -381,7 +381,7 @@ def run_faucet_extrinsic( ): return False - if not os.getenv("USE_TORCH"): + if not use_torch(): return False # Unlock coldkey diff --git a/bittensor/extrinsics/root.py b/bittensor/extrinsics/root.py index 916dde3700..0254ffa523 100644 --- a/bittensor/extrinsics/root.py +++ b/bittensor/extrinsics/root.py @@ -1,7 +1,6 @@ # The MIT License (MIT) # Copyright © 2021 Yuma Rao # Copyright © 2023 Opentensor Foundation -import os # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -27,7 +26,7 @@ from typing import Union, List import bittensor.utils.weight_utils as weight_utils from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME -from bittensor.utils import torch +from bittensor.utils.registration import torch, use_torch logger = logging.getLogger(BITTENSOR_LOGGER_NAME) @@ -136,13 +135,13 @@ def set_root_weights_extrinsic( if isinstance(netuids, list): netuids = ( torch.tensor(netuids, dtype=torch.int64) - if os.getenv("USE_TORCH") + if use_torch() else np.array(netuids, dtype=np.int64) ) if isinstance(weights, list): weights = ( torch.tensor(weights, dtype=torch.float32) - if os.getenv("USE_TORCH") + if use_torch() else np.array(weights, dtype=np.float32) ) @@ -153,12 +152,12 @@ def set_root_weights_extrinsic( # Get non zero values. non_zero_weight_idx = ( torch.argwhere(weights > 0).squeeze(dim=1) - if os.getenv("USE_TORCH") + if use_torch() else np.argwhere(weights > 0).squeeze(axis=1) ) non_zero_weights = weights[non_zero_weight_idx] non_zero_weights_size = ( - non_zero_weights.numel() if os.getenv("USE_TORCH") else non_zero_weights.size + non_zero_weights.numel() if use_torch() else non_zero_weights.size ) if non_zero_weights_size < min_allowed_weights: raise ValueError( diff --git a/bittensor/extrinsics/set_weights.py b/bittensor/extrinsics/set_weights.py index e0fd0944db..928ae85b3e 100644 --- a/bittensor/extrinsics/set_weights.py +++ b/bittensor/extrinsics/set_weights.py @@ -26,7 +26,7 @@ from typing import Union, Tuple import bittensor.utils.weight_utils as weight_utils from bittensor.btlogging.defines import BITTENSOR_LOGGER_NAME -from bittensor.utils import torch +from bittensor.utils.registration import torch, use_torch logger = logging.getLogger(BITTENSOR_LOGGER_NAME) @@ -68,7 +68,7 @@ def set_weights_extrinsic( Flag is ``true`` if extrinsic was finalized or uncluded in the block. If we did not wait for finalization / inclusion, the response is ``true``. """ # First convert types. - if os.getenv("USE_TORCH"): + if use_torch(): if isinstance(uids, list): uids = torch.tensor(uids, dtype=torch.int64) if isinstance(weights, list): diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index 940bc6bab2..8bad4d6c78 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -28,7 +28,7 @@ from typing import List, Optional, Union, Tuple from bittensor.chain_data import AxonInfo -from bittensor.utils.registration import torch +from bittensor.utils.registration import torch, use_torch METAGRAPH_STATE_DICT_NDARRAY_KEYS = [ "version", @@ -50,10 +50,6 @@ ] -def use_torch() -> bool: - return bool(os.getenv("USE_TORCH") == "1") - - def get_save_dir(network: str, netuid: int) -> str: """ Return directory path from ``network`` and ``netuid``. @@ -851,7 +847,7 @@ def load_from_path(self, dir_path: str) -> "metagraph": # type: ignore pass -BaseClass = torch.nn.Module if use_torch() else object +BaseClass: Union["torch.nn.Module", object] = torch.nn.Module if use_torch() else object class TorchMetaGraph(MetagraphMixin, BaseClass): # type: ignore @@ -1182,5 +1178,4 @@ def load_from_path(self, dir_path: str) -> "metagraph": # type: ignore return self -print("USE_TORCH", use_torch()) metagraph = TorchMetaGraph if use_torch() else NonTorchMetagraph diff --git a/bittensor/tensor.py b/bittensor/tensor.py index 3b4c090845..f6432e82ce 100644 --- a/bittensor/tensor.py +++ b/bittensor/tensor.py @@ -23,7 +23,7 @@ import pydantic import msgpack_numpy from typing import Optional, Union, List -from bittensor.utils import torch +from bittensor.utils.registration import torch, use_torch NUMPY_DTYPES = { "float16": np.float16, @@ -37,7 +37,7 @@ "bool": bool, } -if os.getenv("USE_TORCH"): +if use_torch(): TORCH_DTYPES = { "torch.float16": torch.float16, "torch.float32": torch.float32, @@ -70,11 +70,11 @@ def cast_dtype(raw: Union[None, np.dtype, "torch.dtype", str]) -> str: return None if isinstance(raw, np.dtype): return NUMPY_DTYPES[raw] - elif os.getenv("USE_TORCH"): + elif use_torch(): if isinstance(raw, torch.dtype): return TORCH_DTYPES[raw] elif isinstance(raw, str): - if os.getenv("USE_TORCH"): + if use_torch(): assert ( raw in TORCH_DTYPES ), f"{raw} not a valid torch type in dict {TORCH_DTYPES}" @@ -125,7 +125,7 @@ class tensor: def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]): if isinstance(tensor, list) or isinstance(tensor, np.ndarray): tensor = ( - torch.tensor(tensor) if os.getenv("USE_TORCH") else np.array(tensor) + torch.tensor(tensor) if use_torch() else np.array(tensor) ) return Tensor.serialize(tensor=tensor) @@ -152,7 +152,7 @@ def tolist(self) -> List[object]: def numpy(self) -> "numpy.ndarray": return ( self.deserialize().detach().numpy() - if os.getenv("USE_TORCH") + if use_torch() else self.deserialize() ) @@ -171,7 +171,7 @@ def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]: numpy_object = msgpack.unpackb( buffer_bytes, object_hook=msgpack_numpy.decode ).copy() - if os.getenv("USE_TORCH"): + if use_torch(): torch_object = torch.as_tensor(numpy_object) # Reshape does not work for (0) or [0] if not (len(shape) == 1 and shape[0] == 0): @@ -201,7 +201,7 @@ def serialize(tensor: Union["np.ndarray", "torch.Tensor"]) -> "Tensor": shape = list(tensor.shape) if len(shape) == 0: shape = [0] - if os.getenv("USE_TORCH"): + if use_torch(): torch_numpy = tensor.cpu().detach().numpy().copy() data_buffer = base64.b64encode( msgpack.packb(torch_numpy, default=msgpack_numpy.encode) diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index 09cf9828bc..72d053ea7a 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -27,7 +27,7 @@ from .wallet_utils import * # noqa F401 from .version import version_checking, check_version, VersionCheckError -from .registration import torch +from .registration import torch, use_torch RAOPERTAO = 1e9 U16_MAX = 65535 @@ -49,6 +49,29 @@ def _unbiased_topk( axis=0, return_type: str = "numpy", ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple["torch.Tensor", "torch.LongTensor"]]: + """Selects topk as in torch.topk but does not bias lower indices when values are equal. + Args: + values: (np.ndarray) if using numpy, (torch.Tensor) if using torch: + Values to index into. + k: (int): + Number to take. + dim: (int): + Dimension to index into (used by Torch) + sorted: (bool): + Whether to sort indices. + largest: (bool): + Whether to take the largest value. + axis: (int): + Axis along which to index into (used by Numpy) + return_type: (str): + Whether or use torch or numpy approach + + Return: + topk: (np.ndarray) if using numpy, (torch.Tensor) if using torch: + topk k values. + indices: (np.ndarray) if using numpy, (torch.LongTensor) if using torch: + indices of the topk values. + """ if return_type == "torch": permutation = torch.randperm(values.shape[dim]) permuted_values = values[permutation] @@ -80,12 +103,20 @@ def unbiased_topk( largest: bool = True, axis: int = 0, ) -> Union[Tuple[np.ndarray, np.ndarray], Tuple["torch.Tensor", "torch.LongTensor"]]: - r"""Selects topk as in torch.topk but does not bias lower indices when values are equal. + """Selects topk as in torch.topk but does not bias lower indices when values are equal. Args: values: (np.ndarray) if using numpy, (torch.Tensor) if using torch: Values to index into. k: (int): Number to take. + dim: (int): + Dimension to index into (used by Torch) + sorted: (bool): + Whether to sort indices. + largest: (bool): + Whether to take the largest value. + axis: (int): + Axis along which to index into (used by Numpy) Return: topk: (np.ndarray) if using numpy, (torch.Tensor) if using torch: @@ -93,7 +124,7 @@ def unbiased_topk( indices: (np.ndarray) if using numpy, (torch.LongTensor) if using torch: indices of the topk values. """ - if os.getenv("USE_TORCH"): + if use_torch(): return _unbiased_topk( values, k, dim, sorted, largest, axis, return_type="torch" ) diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index 69f38f44db..8ff155dd2b 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -25,6 +25,10 @@ torch = None +def use_torch() -> bool: + return True if os.getenv("USE_TORCH") == "1" else False + + class Torch: def __init__(self): self._transformed = False @@ -49,7 +53,7 @@ def __bool__(self): return False def __getattr__(self, name): - if not self._transformed and os.getenv("USE_TORCH"): + if not self._transformed and use_torch(): self._transform() if self._transformed: return getattr(self, name) @@ -57,7 +61,7 @@ def __getattr__(self, name): self._error() def __call__(self, *args, **kwargs): - if not self._transformed and os.getenv("USE_TORCH"): + if not self._transformed and use_torch(): self._transform() if self._transformed: return self(*args, **kwargs) @@ -65,7 +69,7 @@ def __call__(self, *args, **kwargs): self._error() -if not torch or not os.getenv("USE_TORCH"): +if not torch or not use_torch(): torch = Torch() diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 3efe2310dc..8499a950ff 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -23,16 +23,12 @@ import bittensor from numpy.typing import NDArray from typing import Tuple, List, Union -from bittensor.utils import torch +from bittensor.utils.registration import torch, use_torch U32_MAX = 4294967295 U16_MAX = 65535 -def use_torch() -> bool: - return True if os.getenv("USE_TORCH") == "1" else False - - def normalize_max_weight( x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1 ) -> Union[NDArray[np.float32], "torch.FloatTensor"]: diff --git a/tests/integration_tests/test_metagraph_integration.py b/tests/integration_tests/test_metagraph_integration.py index c7396a5a9c..5dbb9ddfc1 100644 --- a/tests/integration_tests/test_metagraph_integration.py +++ b/tests/integration_tests/test_metagraph_integration.py @@ -58,7 +58,6 @@ def test_load_sync_save(self): self.metagraph.save() def test_load_sync_save_from_torch(self): - # os.environ["USE_TORCH"] = "1" self.metagraph.sync(lite=True, subtensor=self.sub) def deprecated_save_torch(metagraph): @@ -74,7 +73,6 @@ def deprecated_save_torch(metagraph): deprecated_save_torch(self.metagraph) self.metagraph.load() - # del os.environ["USE_TORCH"] def test_state_dict(self): self.metagraph.load() From 02968f8d0298dac146ba78b63ecd0d24552e6704 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 21:29:22 +0200 Subject: [PATCH 15/18] Flake8 and Black --- bittensor/chain_data.py | 1 - bittensor/commands/root.py | 5 +++-- bittensor/extrinsics/registration.py | 1 - bittensor/extrinsics/set_weights.py | 1 - bittensor/tensor.py | 9 ++------- example.env | 6 ++++++ tests/unit_tests/extrinsics/test_root.py | 1 - tests/unit_tests/test_dendrite.py | 2 +- tests/unit_tests/utils/test_utils.py | 1 - 9 files changed, 12 insertions(+), 15 deletions(-) create mode 100644 example.env diff --git a/bittensor/chain_data.py b/bittensor/chain_data.py index 11fed3ec67..548fa40ede 100644 --- a/bittensor/chain_data.py +++ b/bittensor/chain_data.py @@ -15,7 +15,6 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. import bittensor -import os import json from enum import Enum from dataclasses import dataclass, asdict diff --git a/bittensor/commands/root.py b/bittensor/commands/root.py index bd819179d5..75c6cb15f6 100644 --- a/bittensor/commands/root.py +++ b/bittensor/commands/root.py @@ -15,7 +15,6 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -import os import re import typing import argparse @@ -536,7 +535,9 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"): else np.array(matched_netuids, dtype=np.int64) ) - matched_weights = [float(weight) for weight in re.split(r"[ ,]+", cli.config.weights)] + matched_weights = [ + float(weight) for weight in re.split(r"[ ,]+", cli.config.weights) + ] weights = ( torch.tensor(matched_weights, dtype=torch.float32) if use_torch() diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index fd2fe52675..193c98238c 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -17,7 +17,6 @@ # DEALINGS IN THE SOFTWARE. import bittensor -import os import time from rich.prompt import Confirm from typing import List, Union, Optional, Tuple diff --git a/bittensor/extrinsics/set_weights.py b/bittensor/extrinsics/set_weights.py index 928ae85b3e..2a8c972aa3 100644 --- a/bittensor/extrinsics/set_weights.py +++ b/bittensor/extrinsics/set_weights.py @@ -19,7 +19,6 @@ import bittensor import logging -import os import numpy as np from numpy.typing import NDArray from rich.prompt import Confirm diff --git a/bittensor/tensor.py b/bittensor/tensor.py index f6432e82ce..d9d9902f47 100644 --- a/bittensor/tensor.py +++ b/bittensor/tensor.py @@ -1,7 +1,6 @@ # The MIT License (MIT) # Copyright © 2021 Yuma Rao # Copyright © 2022 Opentensor Foundation -import os # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated # documentation files (the “Software”), to deal in the Software without restriction, including without limitation @@ -124,9 +123,7 @@ def cast_shape(raw: Union[None, List[int], str]) -> str: class tensor: def __new__(cls, tensor: Union[list, np.ndarray, "torch.Tensor"]): if isinstance(tensor, list) or isinstance(tensor, np.ndarray): - tensor = ( - torch.tensor(tensor) if use_torch() else np.array(tensor) - ) + tensor = torch.tensor(tensor) if use_torch() else np.array(tensor) return Tensor.serialize(tensor=tensor) @@ -151,9 +148,7 @@ def tolist(self) -> List[object]: def numpy(self) -> "numpy.ndarray": return ( - self.deserialize().detach().numpy() - if use_torch() - else self.deserialize() + self.deserialize().detach().numpy() if use_torch() else self.deserialize() ) def deserialize(self) -> Union["np.ndarray", "torch.Tensor"]: diff --git a/example.env b/example.env new file mode 100644 index 0000000000..de5fb400ed --- /dev/null +++ b/example.env @@ -0,0 +1,6 @@ +# To use Torch functionality in bittensor, you must set the USE_TORCH flag to 1: +USE_TORCH=1 + +# If set to 0 (or anything else), you will use the numpy functions. +# This is generally what you want unless you have a specific reason for using torch +# such as POW registration or legacy interoperability. \ No newline at end of file diff --git a/tests/unit_tests/extrinsics/test_root.py b/tests/unit_tests/extrinsics/test_root.py index a8de9e21c1..131ca2303d 100644 --- a/tests/unit_tests/extrinsics/test_root.py +++ b/tests/unit_tests/extrinsics/test_root.py @@ -1,4 +1,3 @@ -import os import pytest from unittest.mock import MagicMock, patch from bittensor.subtensor import subtensor as Subtensor diff --git a/tests/unit_tests/test_dendrite.py b/tests/unit_tests/test_dendrite.py index 09219816e8..61011457a1 100644 --- a/tests/unit_tests/test_dendrite.py +++ b/tests/unit_tests/test_dendrite.py @@ -21,7 +21,7 @@ import pytest import typing import bittensor -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock from tests.helpers import _get_mock_wallet from bittensor.synapse import TerminalInfo diff --git a/tests/unit_tests/utils/test_utils.py b/tests/unit_tests/utils/test_utils.py index 88938ed8f4..b03ab6e99c 100644 --- a/tests/unit_tests/utils/test_utils.py +++ b/tests/unit_tests/utils/test_utils.py @@ -19,7 +19,6 @@ import numpy as np import bittensor.utils.weight_utils as weight_utils -import bittensor.utils as unbiased_topk import pytest From 573fda515c1df8658a8c45d253947171cf1bb55d Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 21:56:31 +0200 Subject: [PATCH 16/18] Flake8 and README update. --- README.md | 5 +++++ bittensor/utils/weight_utils.py | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0d3a8e6fd3..3c6193247a 100644 --- a/README.md +++ b/README.md @@ -221,6 +221,11 @@ source ~/.bashrc # Reload Bash configuration to take effect # The Bittensor Package The bittensor package contains data structures for interacting with the bittensor ecosystem, writing miners, validators and querying the network. Additionally, it provides many utilities for efficient serialization of Tensors over the wire, performing data analysis of the network, and other useful utilities. +In the 7.0.0 release, we have removed `torch` by default. However, you can still use `torch` by setting the environment variable +`USE_TORCH=1` and making sure that you have installed the `torch` library. +You can install `torch` by running `pip install bittensor[torch]` (if installing via PyPI), or by running `pip install -e ".[torch]"` (if installing from source). +We will not be adding any new functionality based on torch. + Wallet: Interface over locally stored bittensor hot + coldkey styled wallets. ```python import bittensor diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 8499a950ff..109951e14b 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -18,7 +18,6 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. -import os import numpy as np import bittensor from numpy.typing import NDArray From 419551529751b8d95a4055e4055fe6fb6c5b2254 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 22:26:33 +0200 Subject: [PATCH 17/18] Updated error msg for torch, added error msg for running the faucet command without torch. --- bittensor/extrinsics/registration.py | 9 +++++---- bittensor/subtensor.py | 3 ++- bittensor/utils/registration.py | 8 ++++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index 193c98238c..8be4963180 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -341,7 +341,7 @@ def run_faucet_extrinsic( num_processes: Optional[int] = None, update_interval: Optional[int] = None, log_verbose: bool = False, -) -> bool: +) -> Tuple[bool, str]: r"""Runs a continual POW to get a faucet of TAO on the test net. Args: @@ -378,10 +378,11 @@ def run_faucet_extrinsic( subtensor.network, ) ): - return False + return False, "" if not use_torch(): - return False + torch.error() + return False, "Requires torch" # Unlock coldkey wallet.coldkey @@ -401,7 +402,7 @@ def run_faucet_extrinsic( if not torch.cuda.is_available(): if prompt: bittensor.__console__.print("CUDA is not available.") - return False + return False, "CUDA is not available." pow_result: Optional[POWSolution] = create_pow( subtensor, wallet, diff --git a/bittensor/subtensor.py b/bittensor/subtensor.py index c6ce9bd370..409549117c 100644 --- a/bittensor/subtensor.py +++ b/bittensor/subtensor.py @@ -1052,7 +1052,7 @@ def run_faucet( This is for testnet ONLY and is disabled currently. You must build your own staging subtensor chain with the ``--features pow-faucet`` argument to enable this. """ - return run_faucet_extrinsic( + result, _ = run_faucet_extrinsic( subtensor=self, wallet=wallet, wait_for_inclusion=wait_for_inclusion, @@ -1067,6 +1067,7 @@ def run_faucet( update_interval=update_interval, log_verbose=log_verbose, ) + return result def burned_register( self, diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index 8ff155dd2b..ff3816ddbb 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -36,9 +36,13 @@ def __init__(self): @staticmethod def _error(): bittensor.logging.warning( - "This command requires torch. Please install torch package." + "This command requires torch. You can install torch for bittensor" + ' with `pip install bittensor[torch]` or `pip install ".[torch]"`' + " if installing from source, and then run the command with USE_TORCH=1 {command}" ) - raise ImportError + + def error(self): + self._error() def _transform(self): try: From 56855fbaa626d9f7b7432fb4ef10b07a5e7585e2 Mon Sep 17 00:00:00 2001 From: bhimes Date: Tue, 21 May 2024 23:28:29 +0200 Subject: [PATCH 18/18] Updated tests to conform to the correct type structure of the method. --- tests/unit_tests/extrinsics/test_registration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/extrinsics/test_registration.py b/tests/unit_tests/extrinsics/test_registration.py index 9c065207b8..861ce6b462 100644 --- a/tests/unit_tests/extrinsics/test_registration.py +++ b/tests/unit_tests/extrinsics/test_registration.py @@ -107,7 +107,9 @@ def test_run_faucet_extrinsic_happy_path( # Assert if isinstance(result, tuple): assert result[0] == expected - mock_subtensor.substrate.submit_extrinsic.assert_called() + if result[0] is True: + # Checks only if successful + mock_subtensor.substrate.submit_extrinsic.assert_called() else: assert result == expected mock_subtensor.get_balance.assert_called_with("mock_address") @@ -145,7 +147,7 @@ def test_run_faucet_extrinsic_edge_cases( ) # Assert - assert result == expected + assert result[0] == expected @pytest.mark.parametrize(