Skip to content

Commit

Permalink
Merge pull request #1904 from opentensor/feature/add-back-compatibili…
Browse files Browse the repository at this point in the history
…ty-with-torch/thewhaleking

Add back compatibility with torch
  • Loading branch information
thewhaleking authored May 21, 2024
2 parents bfef073 + 56855fb commit 9653271
Show file tree
Hide file tree
Showing 21 changed files with 1,014 additions and 333 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,11 @@ source ~/.bashrc # Reload Bash configuration to take effect
# The Bittensor Package
The bittensor package contains data structures for interacting with the bittensor ecosystem, writing miners, validators and querying the network. Additionally, it provides many utilities for efficient serialization of Tensors over the wire, performing data analysis of the network, and other useful utilities.
In the 7.0.0 release, we have removed `torch` by default. However, you can still use `torch` by setting the environment variable
`USE_TORCH=1` and making sure that you have installed the `torch` library.
You can install `torch` by running `pip install bittensor[torch]` (if installing via PyPI), or by running `pip install -e ".[torch]"` (if installing from source).
We will not be adding any new functionality based on torch.
Wallet: Interface over locally stored bittensor hot + coldkey styled wallets.
```python
import bittensor
Expand Down
159 changes: 138 additions & 21 deletions bittensor/chain_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import bittensor

import json
from enum import Enum
from dataclasses import dataclass, asdict
Expand All @@ -27,6 +26,7 @@

from .utils import networking as net, U16_MAX, U16_NORMALIZED_FLOAT
from .utils.balance import Balance
from .utils.registration import torch, use_torch

custom_rpc_type_registry = {
"types": {
Expand Down Expand Up @@ -264,14 +264,43 @@ def from_neuron_info(cls, neuron_info: dict) -> "AxonInfo":
coldkey=neuron_info["coldkey"],
)

def to_parameter_dict(self) -> dict[str, Union[int, str]]:
r"""Returns a dict of the subnet info."""
return self.__dict__
def _to_parameter_dict(
self, return_type: str
) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]:
if return_type == "torch":
return torch.nn.ParameterDict(self.__dict__)
else:
return self.__dict__

def to_parameter_dict(
self,
) -> Union[dict[str, Union[int, str]], "torch.nn.ParameterDict"]:
"""Returns a torch tensor or dict of the subnet info, depending on the USE_TORCH flag set"""
if use_torch():
return self._to_parameter_dict("torch")
else:
return self._to_parameter_dict("numpy")

@classmethod
def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "AxonInfo":
r"""Returns an axon_info object from a parameter_dict."""
return cls(**parameter_dict)
def _from_parameter_dict(
cls,
parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"],
return_type: str,
) -> "AxonInfo":
if return_type == "torch":
return cls(**dict(parameter_dict))
else:
return cls(**parameter_dict)

@classmethod
def from_parameter_dict(
cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"]
) -> "AxonInfo":
"""Returns an axon_info object from a torch parameter_dict or a parameter dict."""
if use_torch():
return cls._from_parameter_dict(parameter_dict, "torch")
else:
return cls._from_parameter_dict(parameter_dict, "numpy")


class ChainDataType(Enum):
Expand Down Expand Up @@ -980,15 +1009,42 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetInfo":
owner_ss58=ss58_encode(decoded["owner"], bittensor.__ss58_format__),
)

def to_parameter_dict(self) -> dict[str, Any]:
r"""Returns a dict of the subnet info."""
return self.__dict__
def _to_parameter_dict(
self, return_type: str
) -> Union[dict[str, Any], "torch.nn.ParameterDict"]:
if return_type == "torch":
return torch.nn.ParameterDict(self.__dict__)
else:
return self.__dict__

def to_parameter_dict(self) -> Union[dict[str, Any], "torch.nn.ParameterDict"]:
"""Returns a torch tensor or dict of the subnet info."""
if use_torch():
return self._to_parameter_dict("torch")
else:
return self._to_parameter_dict("numpy")

@classmethod
def _from_parameter_dict_torch(
cls, parameter_dict: "torch.nn.ParameterDict"
) -> "SubnetInfo":
"""Returns a SubnetInfo object from a torch parameter_dict."""
return cls(**dict(parameter_dict))

@classmethod
def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo":
def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo":
r"""Returns a SubnetInfo object from a parameter_dict."""
return cls(**parameter_dict)

@classmethod
def from_parameter_dict(
cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"]
) -> "SubnetInfo":
if use_torch():
return cls._from_parameter_dict_torch(parameter_dict)
else:
return cls._from_parameter_dict_numpy(parameter_dict)


@dataclass
class SubnetHyperparameters:
Expand Down Expand Up @@ -1074,15 +1130,46 @@ def fix_decoded_values(cls, decoded: Dict) -> "SubnetHyperparameters":
difficulty=decoded["difficulty"],
)

def to_parameter_dict(self) -> dict[str, Union[int, float, bool]]:
r"""Returns a dict of the subnet hyperparameters."""
return self.__dict__
def _to_parameter_dict_torch(
self, return_type: str
) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]:
if return_type == "torch":
return torch.nn.ParameterDict(self.__dict__)
else:
return self.__dict__

def to_parameter_dict(
self,
) -> Union[dict[str, Union[int, float, bool]], "torch.nn.ParameterDict"]:
"""Returns a torch tensor or dict of the subnet hyperparameters."""
if use_torch():
return self._to_parameter_dict_torch("torch")
else:
return self._to_parameter_dict_torch("numpy")

@classmethod
def _from_parameter_dict_torch(
cls, parameter_dict: "torch.nn.ParameterDict"
) -> "SubnetHyperparameters":
"""Returns a SubnetHyperparameters object from a torch parameter_dict."""
return cls(**dict(parameter_dict))

@classmethod
def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "SubnetInfo":
r"""Returns a SubnetHyperparameters object from a parameter_dict."""
def _from_parameter_dict_numpy(
cls, parameter_dict: dict[str, Any]
) -> "SubnetHyperparameters":
"""Returns a SubnetHyperparameters object from a parameter_dict."""
return cls(**parameter_dict)

@classmethod
def from_parameter_dict(
cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"]
) -> "SubnetHyperparameters":
if use_torch():
return cls._from_parameter_dict_torch(parameter_dict)
else:
return cls._from_parameter_dict_numpy(parameter_dict)


@dataclass
class IPInfo:
Expand Down Expand Up @@ -1137,15 +1224,45 @@ def fix_decoded_values(cls, decoded: Dict) -> "IPInfo":
protocol=decoded["ip_type_and_protocol"] & 0xF,
)

def to_parameter_dict(self) -> dict[str, Union[str, int]]:
r"""Returns a dict of the subnet ip info."""
return self.__dict__
def _to_parameter_dict(
self, return_type: str
) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]:
"""Returns a torch tensor of the subnet info."""
if return_type == "torch":
return torch.nn.ParameterDict(self.__dict__)
else:
return self.__dict__

def to_parameter_dict(
self,
) -> Union[dict[str, Union[str, int]], "torch.nn.ParameterDict"]:
"""Returns a torch tensor or dict of the subnet IP info."""
if use_torch():
return self._to_parameter_dict("torch")
else:
return self._to_parameter_dict("numpy")

@classmethod
def _from_parameter_dict_torch(
cls, parameter_dict: "torch.nn.ParameterDict"
) -> "IPInfo":
"""Returns a IPInfo object from a torch parameter_dict."""
return cls(**dict(parameter_dict))

@classmethod
def from_parameter_dict(cls, parameter_dict: dict[str, Any]) -> "IPInfo":
r"""Returns a IPInfo object from a parameter_dict."""
def _from_parameter_dict_numpy(cls, parameter_dict: dict[str, Any]) -> "IPInfo":
"""Returns a IPInfo object from a parameter_dict."""
return cls(**parameter_dict)

@classmethod
def from_parameter_dict(
cls, parameter_dict: Union[dict[str, Any], "torch.nn.ParameterDict"]
) -> "IPInfo":
if use_torch():
return cls._from_parameter_dict_torch(parameter_dict)
else:
return cls._from_parameter_dict_numpy(parameter_dict)


# Senate / Proposal data

Expand Down
35 changes: 27 additions & 8 deletions bittensor/commands/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# DEALINGS IN THE SOFTWARE.

import re
import numpy as np
import typing
import argparse
import numpy as np
Expand All @@ -25,6 +24,7 @@
from rich.prompt import Prompt
from rich.table import Table
from .utils import get_delegates_details, DelegatesDetails
from bittensor.utils.registration import torch, use_torch

from . import defaults

Expand Down Expand Up @@ -301,7 +301,11 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
f"Boosting weight for netuid {cli.config.netuid} from {prev_weight} -> {new_weight}"
)
my_weights[cli.config.netuid] = new_weight
all_netuids = np.arange(len(my_weights))
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)

bittensor.__console__.print("Setting root weights...")
subtensor.root_set_weights(
Expand Down Expand Up @@ -419,7 +423,11 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
my_weights = root.weights[my_uid]
my_weights[cli.config.netuid] -= cli.config.amount
my_weights[my_weights < 0] = 0 # Ensure weights don't go negative
all_netuids = np.arange(len(my_weights))
all_netuids = (
torch.tensor(list(range(len(my_weights))))
if use_torch()
else np.arange(len(my_weights))
)

subtensor.root_set_weights(
wallet=wallet,
Expand Down Expand Up @@ -520,12 +528,23 @@ def _run(cli: "bittensor.cli", subtensor: "bittensor.subtensor"):
cli.config.weights = Prompt.ask(f"Enter weights (e.g. {example})")

# Parse from string
netuids = np.array(
list(map(int, re.split(r"[ ,]+", cli.config.netuids))), dtype=np.int64
matched_netuids = list(map(int, re.split(r"[ ,]+", cli.config.netuids)))
netuids = (
torch.tensor(matched_netuids, dtype=torch.long)
if use_torch()
else np.array(matched_netuids, dtype=np.int64)
)
weights = np.array(
list(map(float, re.split(r"[ ,]+", cli.config.weights))),
dtype=np.float32,

matched_weights = [
float(weight) for weight in re.split(r"[ ,]+", cli.config.weights)
]
weights = (
torch.tensor(matched_weights, dtype=torch.float32)
if use_torch()
else np.array(
matched_weights,
dtype=np.float32,
)
)

# Run the set weights operation.
Expand Down
8 changes: 3 additions & 5 deletions bittensor/commands/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import bittensor
import requests
from bittensor.utils.registration import maybe_get_torch
from bittensor.utils.registration import torch
from typing import List, Dict, Any, Optional
from rich.prompt import Confirm, PromptBase
from dataclasses import dataclass
Expand Down Expand Up @@ -78,11 +78,9 @@ def check_netuid_set(

def check_for_cuda_reg_config(config: "bittensor.config") -> None:
"""Checks, when CUDA is available, if the user would like to register with their CUDA device."""

torch = maybe_get_torch()
if torch is not None and torch.cuda.is_available():
if torch and torch.cuda.is_available():
if not config.no_prompt:
if config.pow_register.cuda.get("use_cuda") == None: # flag not set
if config.pow_register.cuda.get("use_cuda") is None: # flag not set
# Ask about cuda registration only if a CUDA device is available.
cuda = Confirm.ask("Detected CUDA device, use CUDA for registration?\n")
config.pow_register.cuda.use_cuda = cuda
Expand Down
31 changes: 25 additions & 6 deletions bittensor/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
import time
import aiohttp
import bittensor
from typing import Union, Optional, List, Union, AsyncGenerator, Any
from typing import Optional, List, Union, AsyncGenerator, Any
from bittensor.utils.registration import torch, use_torch


class dendrite:
class DendriteMixin:
"""
The Dendrite class represents the abstracted implementation of a network client module.
Expand Down Expand Up @@ -104,7 +105,7 @@ def __init__(
The user's wallet or keypair used for signing messages. Defaults to ``None``, in which case a new :func:`bittensor.wallet().hotkey` is generated and used.
"""
# Initialize the parent class
super(dendrite, self).__init__()
super(DendriteMixin, self).__init__()

# Unique identifier for the instance
self.uuid = str(uuid.uuid1())
Expand All @@ -121,9 +122,6 @@ def __init__(

self._session: Optional[aiohttp.ClientSession] = None

async def __call__(self, *args, **kwargs):
return await self.forward(*args, **kwargs)

@property
async def session(self) -> aiohttp.ClientSession:
"""
Expand Down Expand Up @@ -808,3 +806,24 @@ def __del__(self):
del dendrite # This will implicitly invoke the __del__ method and close the session.
"""
self.close_session()


# For back-compatibility with torch
BaseModel: Union["torch.nn.Module", object] = torch.nn.Module if use_torch() else object


class dendrite(DendriteMixin, BaseModel): # type: ignore
def __init__(
self, wallet: Optional[Union[bittensor.wallet, bittensor.Keypair]] = None
):
if use_torch():
torch.nn.Module.__init__(self)
DendriteMixin.__init__(self, wallet)


if not use_torch():

async def call(self, *args, **kwargs):
return await self.forward(*args, **kwargs)

dendrite.__call__ = call
Loading

0 comments on commit 9653271

Please sign in to comment.