diff --git a/.github/workflows/pytests.yml b/.github/workflows/pytests.yml index 83908cb0..e7300a85 100644 --- a/.github/workflows/pytests.yml +++ b/.github/workflows/pytests.yml @@ -79,4 +79,4 @@ jobs: # run tests #---------------------------------------------- - name: Run Python SDK tests - run: poetry run pytest -s + run: poetry run pytest -s tests diff --git a/nibiru/__init__.py b/nibiru/__init__.py index 59a2f2ab..baf7575b 100644 --- a/nibiru/__init__.py +++ b/nibiru/__init__.py @@ -16,11 +16,11 @@ ProtobufMessage = google.protobuf.message.Message -import nibiru.common # noqa import nibiru.msg # noqa -from nibiru.client import GrpcClient # noqa -from nibiru.common import Coin, Direction, PoolAsset, Side, TxConfig, TxType # noqa +import nibiru.pytypes # noqa +from nibiru.grpc_client import GrpcClient # noqa from nibiru.network import Network # noqa +from nibiru.pytypes import Coin, Direction, PoolAsset, Side, TxConfig, TxType # noqa from nibiru.sdk import Sdk # noqa from nibiru.transaction import Transaction # noqa from nibiru.wallet import Address, PrivateKey, PublicKey # noqa diff --git a/nibiru/client.py b/nibiru/grpc_client.py similarity index 100% rename from nibiru/client.py rename to nibiru/grpc_client.py diff --git a/nibiru/msg/bank.py b/nibiru/msg/bank.py index 0c24f914..67d31b7a 100644 --- a/nibiru/msg/bank.py +++ b/nibiru/msg/bank.py @@ -5,7 +5,7 @@ from nibiru_proto.proto.cosmos.distribution.v1beta1 import tx_pb2 as tx_pb from nibiru_proto.proto.cosmos.staking.v1beta1 import tx_pb2 as staking_pb -from nibiru.common import Coin, PythonMsg +from nibiru.pytypes import Coin, PythonMsg @dataclasses.dataclass diff --git a/nibiru/msg/dex.py b/nibiru/msg/dex.py index 09471fd1..ba0ebcad 100644 --- a/nibiru/msg/dex.py +++ b/nibiru/msg/dex.py @@ -1,10 +1,10 @@ import dataclasses -from typing import List +from typing import List, Union from nibiru_proto.proto.dex.v1 import pool_pb2 as pool_tx_pb from nibiru_proto.proto.dex.v1 import tx_pb2 as pb -from nibiru.common import Coin, PoolAsset, PoolType, PythonMsg +from nibiru.pytypes import Coin, PoolAsset, PoolType, PythonMsg @dataclasses.dataclass @@ -62,9 +62,11 @@ class MsgJoinPool(PythonMsg): sender: str pool_id: int - tokens: List[Coin] + tokens: Union[Coin, List[Coin]] def to_pb(self) -> pb.MsgJoinPool: + if isinstance(self.tokens, Coin): + self.tokens = [self.tokens] return pb.MsgJoinPool( sender=self.sender, pool_id=self.pool_id, @@ -85,7 +87,7 @@ class MsgExitPool(PythonMsg): sender: str pool_id: int - pool_shares: List[Coin] + pool_shares: Coin def to_pb(self) -> pb.MsgExitPool: return pb.MsgExitPool( diff --git a/nibiru/msg/perp.py b/nibiru/msg/perp.py index 7510a14d..fdeebeb8 100644 --- a/nibiru/msg/perp.py +++ b/nibiru/msg/perp.py @@ -3,7 +3,7 @@ from nibiru_proto.proto.perp.v1 import state_pb2 as state_pb from nibiru_proto.proto.perp.v1 import tx_pb2 as pb -from nibiru.common import Coin, PythonMsg, Side +from nibiru.pytypes import Coin, PythonMsg, Side from nibiru.utils import to_sdk_dec, to_sdk_int diff --git a/nibiru/msg/pricefeed.py b/nibiru/msg/pricefeed.py index 8acf6790..9437befe 100644 --- a/nibiru/msg/pricefeed.py +++ b/nibiru/msg/pricefeed.py @@ -3,7 +3,7 @@ from nibiru_proto.proto.pricefeed import tx_pb2 as pb -from nibiru.common import PythonMsg +from nibiru.pytypes import PythonMsg from nibiru.utils import to_sdk_dec, toPbTimestamp diff --git a/nibiru/network.py b/nibiru/network.py index 234b233d..a1676fac 100644 --- a/nibiru/network.py +++ b/nibiru/network.py @@ -12,20 +12,48 @@ @dataclasses.dataclass class Network: + """A representation of a Nibiru network based on its Tendermint RPC, gRPC, + and LCD (REST) endpoints. A 'Network' instance enables interactions with a + running blockchain. + + Attributes: + lcd_endpoint (str): . + grpc_endpoint (str): . + tendermint_rpc_endpoint (str): . + chain_id (str): . + websocket_endpoint (str): . + env (Optional[str]): TODO docs + fee_denom (Optional[str]): Denom for the coin used to pay gas fees. Defaults to "unibi". + + Methods: + customnet: A custom Nibiru network based on environment variables. + Defaults to localnet. + devnet: A development testnet environment that runs the latest release or + pre-release from the nibiru repo. Defaults to 'nibiru-devnet-1'. + localnet: The default local network created by running 'make localnet' in + the nibiru repo. + testnet: A stable testnet environment with public community members. + Think of this as out practice mainnet. Defaults to 'nibiru-testnet-1'. + mainnet: NotImplemented. + + Examples: + >>> from nibiru import Network + >>> network = Network.devnet(2) + >>> network.is_insecure + True + """ + lcd_endpoint: str grpc_endpoint: str tendermint_rpc_endpoint: str chain_id: str - env: str websocket_endpoint: str + env: str = "custom" fee_denom: str = "unibi" - def __post_init__(self): - """ - Update the env value if the dataclass was created without one. - """ - if self.env == "": - self.env = "custom" + @property + def is_insecure(self) -> bool: + return not ("https" in self.tendermint_rpc_endpoint) @classmethod def customnet(cls) -> "Network": @@ -74,8 +102,8 @@ def customnet(cls) -> "Network": @classmethod def testnet(cls, chain_num: int = 1) -> "Network": """ - Testnet is a network open to invited validators. It is more stable than devnet and provides a faucet to get some - funds + Testnet is a network open to invited validators. It is more stable than + devnet and provides a faucet to get some funds Args: chain_num (int): Testnet number diff --git a/nibiru/pytypes/__init__.py b/nibiru/pytypes/__init__.py new file mode 100644 index 00000000..e2132e6e --- /dev/null +++ b/nibiru/pytypes/__init__.py @@ -0,0 +1,16 @@ +# These import statements export the types to 'nibiru.pytypes'. + +from nibiru.pytypes.common import ( # noqa # TODO move constants to a constants.py file.; noqa + GAS_PRICE, + MAX_MEMO_CHARACTERS, + Coin, + Direction, + PoolAsset, + PoolType, + PythonMsg, + Side, + TxConfig, + TxType, +) +from nibiru.pytypes.event import Event, RawEvent, TxLogEvents # noqa +from nibiru.pytypes.tx_resp import RawTxResp, TxResp # noqa diff --git a/nibiru/common.py b/nibiru/pytypes/common.py similarity index 67% rename from nibiru/common.py rename to nibiru/pytypes/common.py index b45b6af7..5ef87023 100644 --- a/nibiru/common.py +++ b/nibiru/pytypes/common.py @@ -1,5 +1,5 @@ import abc -from dataclasses import dataclass +import dataclasses from enum import Enum from nibiru_proto.proto.cosmos.base.v1beta1 import coin_pb2 as cosmos_base_coin_pb @@ -52,7 +52,7 @@ class Direction(Enum): REMOVE = 2 -@dataclass +@dataclasses.dataclass class Coin: amount: float denom: str @@ -61,35 +61,32 @@ def _generate_proto_object(self): return cosmos_base_coin_pb.Coin(amount=str(self.amount), denom=self.denom) -@dataclass +@dataclasses.dataclass class PoolAsset: token: Coin weight: float +@dataclasses.dataclass class TxConfig: - def __init__( - self, - gas_wanted: int = 0, - gas_multiplier: float = 1.25, - gas_price: float = 0, - tx_type: TxType = TxType.ASYNC, - ): - """ - The TxConfig object allows to customize the behavior of the Sdk interface when a transaction is sent. - - Args: - gas_wanted (int, optional): Set the absolute gas_wanted to be used. Defaults to 0. - gas_multiplier (float, optional): Set the gas multiplier that's being applied to the estimated gas. - Defaults to 0. If gas_wanted is set this property is ignored. - gas_price (float, optional): Set the gas price used to calculate the fee. Defaults to 0. - tx_type (TxType, optional): Configure how to execute the tx. Defaults to TxType.ASYNC. - """ + """ + The TxConfig object allows to customize the behavior of the Sdk interface when a transaction is sent. + + Args: + gas_wanted (int, optional): Set the absolute gas_wanted to be used. + Defaults to 0. + gas_multiplier (float, optional): Set the gas multiplier that's being + applied to the estimated gas. If gas_wanted is set, this property + is ignored. Defaults to 0. + gas_price (float, optional): Set the gas price used to calculate the fee. + Defaults to 0.25. + tx_type (TxType, optional): Configure how to execute the tx. Defaults to TxType.ASYNC. + """ - self.gas_multiplier = gas_multiplier - self.gas_wanted = gas_wanted - self.gas_price = gas_price - self.tx_type = tx_type + gas_wanted: int = 0 + gas_multiplier: float = 1.25 + gas_price: float = 0.25 + tx_type: TxType = TxType.ASYNC class PythonMsg(abc.ABC): diff --git a/nibiru/pytypes/event.py b/nibiru/pytypes/event.py new file mode 100644 index 00000000..d2751533 --- /dev/null +++ b/nibiru/pytypes/event.py @@ -0,0 +1,118 @@ +import collections +import pprint +from typing import Dict, List + + +class RawEvent(collections.abc.MutableMapping): + """Dictionary representing a Tendermint event. In the raw TxOutput of a + successful transaciton, it's the value at + + ```python + tx_output['rawLog'][0]['events'] + ``` + + ### Keys (KeyType): + - attributes (List[Dict[str,str]]) + - type (str) + + ### Example: + ```python + {'attributes': [ + {'key': 'recipient', 'value': 'nibi1uvu52rxwqj5ndmm59y6atvx33mru9xrz6sqekr'}, + {'key': 'sender', 'value': 'nibi1zaavvzxez0elundtn32qnk9lkm8kmcsz44g7xl'}, + {'key': 'amount', 'value': '7unibi,70unusd'}], + 'type': 'transfer'} + ``` + """ + + +class Event: + """A Tendermint event. An event contains a type and set of attributes. + Events allow application developers to attach additional information to the + 'ResponseBeginBlock', 'ResponseEndBlock', 'ResponseCheckTx', and 'ResponseDeliverTx' + functions in the ABCI (application blockchain interface). + + In the Tendermint protobuf, the hard definition is: + + ```proto + message Event { + string type = 1; + repeated EventAttribute attributes = 2; + } + message EventAttribute { + bytes key = 1; + bytes value = 2; + bool index = 3; + } + ``` + + - Ref: [cosmos-sdk/types/events.go](https://github.com/cosmos/cosmos-sdk/blob/93abfdd21d9892550da315b10308519b43fb1775/types/events.go#L221) + - Ref: [tendermint/tendermint/proto/tendermint/abci/types.proto](https://github.com/tendermint/tendermint/blob/a6dd0d270abc3c01f223eedee44d8b285ae273f6/proto/tendermint/abci/types.proto) + """ + + type: str + attrs: Dict[str, str] + + def __init__(self, raw_event: RawEvent): + self.type = raw_event["type"] + self.attrs = self.parse_attributes(raw_event["attributes"]) + + @staticmethod + def parse_attributes(raw_attributes: List[Dict[str, str]]) -> Dict[str, str]: + try: + attributes: dict[str, str] = { + kv_dict['key']: kv_dict['value'] for kv_dict in raw_attributes + } + return attributes + except: + raise Exception( + f"failed to parse raw attributes:\n{pprint.pformat(raw_attributes)}" + ) + + def __repr__(self) -> str: + return f"Event(type={self.type}, attrs={self.attrs})" + + def to_dict(self) -> Dict[str, Dict[str, str]]: + return {self.type: self.attrs} + + +class TxLogEvents: + """An element of 'TxResp.rawLog'. This object contains events and messages. + + Keys (KeyType): + type (str) + attributes (List[EventAttribute]) + + Args: + events_raw (List[RawEvent]) + + Attributes: + events (List[Event]) + msgs (List[str]) + events_raw (List[RawEvent]) + event_types (List[str]) + """ + + events: List[Event] + msgs: List[str] + events_raw: List[RawEvent] + event_types: List[str] + + def __init__(self, events_raw: List[RawEvent] = []): + self.events_raw = events_raw + self.events = [Event(raw_event) for raw_event in events_raw] + self.msgs = self.get_msg_types() + + def get_msg_types(self) -> List[str]: + + msgs = [] + self.event_types = [] + for event in self.events: + self.event_types.append(event.type) + if event.type == "message": + msgs.append(event.attrs["action"]) + return msgs + + def __repr__(self) -> str: + self_as_dict = dict(msgs=self.msgs, events=[e.to_dict() for e in self.events]) + return pprint.pformat(self_as_dict, indent=2) diff --git a/nibiru/pytypes/tx_resp.py b/nibiru/pytypes/tx_resp.py new file mode 100644 index 00000000..ec38c3da --- /dev/null +++ b/nibiru/pytypes/tx_resp.py @@ -0,0 +1,110 @@ +import dataclasses +from typing import Any, Dict, List, Union + +from nibiru import utils +from nibiru.pytypes import event + + +@dataclasses.dataclass +class TxResp: + """ + A 'TxResp' represents the response payload from a successful transaction. + + The 'TxResponse' type is defined in [cosmos-sdk/types/abci.pb.go](https://github.com/cosmos/cosmos-sdk/blob/v0.45.10/types/abci.pb.go) + + ### Args & Attributes: + + - height (int): block height at which the transaction was committed. + - txhash (str): unique identifier for the transaction + - data (str): Result bytes. + - rawLog (List[TxLogEvents]): Raw output of the SDK application's logger. + Possibly non-deterministic. This output also contains the events emitted + during the processing of the transaction, which is equivalently + - logs (list): Typed output of the SDK application's logger. + Possibly non-deterministic. + - gasWanted (str): Amount of gas units requested for the transaction. + - gasUsed (str): Amount of gas units consumed by the transaction execution. + - events (list): Tendermint events emitted by processing the transaction. + The events in this attribute include those emitted by both from + the ante handler and the processing of all messages, whereas the + 'rawLog' events are only those emitted when processing messages (with + additional metadata). + - _raw (RawTxResp): The unprocessed form of the transaction resposnse. + + """ + + height: int + txhash: str + data: str + rawLog: List[event.TxLogEvents] + logs: list + gasWanted: int + gasUsed: int + events: list + _raw: 'RawTxResp' + + @classmethod + def from_raw(cls, raw_tx_resp: 'RawTxResp') -> 'TxResp': + return cls( + height=int(raw_tx_resp["height"]), + txhash=raw_tx_resp["txhash"], + data=raw_tx_resp["data"], + rawLog=[ + event.TxLogEvents(msg_log['events']) + for msg_log in raw_tx_resp["rawLog"] + ], + logs=raw_tx_resp["logs"], + gasWanted=int(raw_tx_resp["gasWanted"]), + gasUsed=int(raw_tx_resp["gasUsed"]), + events=raw_tx_resp["events"], + _raw=raw_tx_resp, + ) + + def __repr__(self) -> str: + repr_body = ", ".join( + [ + f"height={self.height}", + f"txhash={self.txhash}", + f"gasUsed={self.gasUsed}", + f"gasWanted={self.gasWanted}", + f"rawLog={self.rawLog}", + ] + ) + return f"TxResp({repr_body})" + + +# from typing import TypedDict +# class RawTxResp(TypedDict): # not available in Python 3.7 +class RawTxResp(dict): + """Proxy for a 'TypedDict' representing a transaction response. + - The 'TxResponse' type is defined in + [cosmos-sdk/types/abci.pb.go](https://github.com/cosmos/cosmos-sdk/blob/v0.45.10/types/abci.pb.go) + + ### Keys (ValueType): + + - height (str): block height at which the transaction was committed. + - txhash (str): unique identifier for the transaction + - data (str): Result bytes. + - rawLog (list): Raw output of the SDK application's logger. + Possibly non-deterministic. This output also contains the events emitted + during the processing of the transaction, which is equivalently + - logs (list): Typed output of the SDK application's logger. + Possibly non-deterministic. + - gasWanted (str): Amount of gas units requested for the transaction. + - gasUsed (str): Amount of gas units consumed by the transaction execution. + - events (list): Tendermint events emitted by processing the transaction. + The events in this attribute include those emitted by both from + the ante handler and the processing of all messages, whereas the + 'rawLog' events are only those emitted when processing messages (with + additional metadata). + """ + + def __new__(cls, _dict: Dict[str, Any]) -> Dict[str, Union[str, list]]: + """Verifies that the dictionary has the expected keys.""" + keys_wanted = ["height", "txhash", "data", "rawLog", "logs"] + [ + "gasWanted", + "gasUsed", + "events", + ] + utils.dict_keys_must_match(_dict, keys_wanted) + return _dict diff --git a/nibiru/query_clients/dex.py b/nibiru/query_clients/dex.py index 93a61d7c..bc97ac12 100644 --- a/nibiru/query_clients/dex.py +++ b/nibiru/query_clients/dex.py @@ -6,7 +6,7 @@ from nibiru_proto.proto.dex.v1 import query_pb2 as dex_type from nibiru_proto.proto.dex.v1 import query_pb2_grpc as dex_query -from nibiru.common import Coin +from nibiru.pytypes import Coin from nibiru.query_clients.util import QueryClient from nibiru.utils import format_fields_nested, from_sdk_dec_n @@ -121,7 +121,9 @@ def pools(self, **kwargs): should_deserialize=False, ) - output = MessageToDict(proto_output)["pools"] + output: dict = MessageToDict(proto_output).get("pools") + if output is None: + output = {} return format_fields_nested( object=format_fields_nested( diff --git a/nibiru/query_clients/vpool.py b/nibiru/query_clients/vpool.py index fad808d5..c183b00e 100644 --- a/nibiru/query_clients/vpool.py +++ b/nibiru/query_clients/vpool.py @@ -3,7 +3,7 @@ from nibiru_proto.proto.vpool.v1 import query_pb2_grpc as vpool_query from nibiru_proto.proto.vpool.v1.state_pb2 import Direction as pbDirection -from nibiru.common import Direction +from nibiru.pytypes import Direction from nibiru.query_clients.util import QueryClient diff --git a/nibiru/sdk.py b/nibiru/sdk.py index 0498efa6..baef36d3 100644 --- a/nibiru/sdk.py +++ b/nibiru/sdk.py @@ -6,13 +6,13 @@ chain. This object depends on the network and transaction configuration the users want. These objects can be set using the -Network and TxConfig classes respectively inside the nibiru/network.py and nibiru/common.py files. +Network and TxConfig classes respectively inside the nibiru/network.py and nibiru/pytypes files. """ import logging -from nibiru.client import GrpcClient -from nibiru.common import TxConfig +from nibiru.grpc_client import GrpcClient from nibiru.network import Network +from nibiru.pytypes import TxConfig from nibiru.tx import BaseTxClient from nibiru.wallet import PrivateKey @@ -40,11 +40,13 @@ class Sdk: Example :: - sdk = ( - Sdk.authorize(val_mnemonic) - .with_config(tx_config) - .with_network(network, network_insecure) - ) + ```python + sdk = ( + Sdk.authorize(val_mnemonic) + .with_config(tx_config) + .with_network(network, network_insecure) + ) + ``` """ query: GrpcClient @@ -93,7 +95,7 @@ def authorize(cls, key: str = None) -> "Sdk": return self def with_network( - self, network: Network, insecure=False, bypass_version_check: bool = False + self, network: Network, bypass_version_check: bool = False ) -> "Sdk": """ Change the network of the sdk to the specified network. @@ -108,7 +110,7 @@ def with_network( """ self.network = network self._with_query_client( - GrpcClient(self.network, insecure, bypass_version_check) + GrpcClient(network, network.is_insecure, bypass_version_check) ) return self diff --git a/nibiru/transaction.py b/nibiru/transaction.py index a691ea79..d1e8e4f6 100644 --- a/nibiru/transaction.py +++ b/nibiru/transaction.py @@ -6,8 +6,8 @@ from nibiru_proto.proto.cosmos.tx.signing.v1beta1 import signing_pb2 as tx_sign from nibiru_proto.proto.cosmos.tx.v1beta1 import tx_pb2 as cosmos_tx_type -from nibiru.client import GrpcClient -from nibiru.common import MAX_MEMO_CHARACTERS +from nibiru.grpc_client import GrpcClient +from nibiru.pytypes import MAX_MEMO_CHARACTERS from nibiru.wallet import PrivateKey, PublicKey diff --git a/nibiru/tx.py b/nibiru/tx.py index 9fdbc585..09eb276b 100644 --- a/nibiru/tx.py +++ b/nibiru/tx.py @@ -1,15 +1,15 @@ import json import logging from copy import deepcopy -from typing import Any, Dict, List, Union +from typing import Any, List, Union from google.protobuf.json_format import MessageToDict from nibiru_proto.proto.cosmos.base.abci.v1beta1 import abci_pb2 as abci_type from nibiru_proto.proto.cosmos.base.v1beta1 import coin_pb2 as cosmos_base_coin_pb -from nibiru.client import GrpcClient -from nibiru.common import GAS_PRICE, PythonMsg, TxConfig, TxType +from nibiru import pytypes as pt from nibiru.exceptions import SimulationError, TxError +from nibiru.grpc_client import GrpcClient from nibiru.network import Network from nibiru.transaction import Transaction from nibiru.wallet import PrivateKey @@ -21,7 +21,7 @@ def __init__( priv_key: PrivateKey, network: Network, client: GrpcClient, - config: TxConfig, + config: pt.TxConfig, ): self.priv_key = priv_key self.network = network @@ -31,10 +31,10 @@ def __init__( def execute_msgs( self, - msgs: Union[PythonMsg, List[PythonMsg]], + msgs: Union[pt.PythonMsg, List[pt.PythonMsg]], get_sequence_from_node: bool = False, **kwargs, - ) -> Dict[str, Any]: + ) -> pt.RawTxResp: """ Execute a message to broadcast a transaction to the node. Simulate the message to generate the gas estimate and send it to the node. @@ -49,7 +49,8 @@ def execute_msgs( TxError: Raw error log from the blockchain if the response code is nonzero. Returns: - dict[str, Any]: The transaction response as a dict in proto3 JSON format. + Union[RawTxResp, Dict[str, Any]]: The transaction response as a dict + in proto3 JSON format. """ if not isinstance(msgs, list): msgs = [msgs] @@ -89,7 +90,7 @@ def execute_msgs( # Convert raw log into a dictionary tx_output["rawLog"] = json.loads(tx_output.get("rawLog", "{}")) - return tx_output + return pt.RawTxResp(tx_output) except SimulationError as err: if ( @@ -112,7 +113,7 @@ def execute_tx( gas_wanted = conf.gas_wanted elif conf.gas_multiplier > 0: gas_wanted = gas_estimate * conf.gas_multiplier - gas_price = GAS_PRICE if conf.gas_price <= 0 else conf.gas_price + gas_price = pt.GAS_PRICE if conf.gas_price <= 0 else conf.gas_price fee = [ cosmos_base_coin_pb.Coin( @@ -133,10 +134,10 @@ def execute_tx( return self._send_tx(tx_raw_bytes, conf.tx_type) - def _send_tx(self, tx_raw_bytes, tx_type: TxType) -> abci_type.TxResponse: - if tx_type == TxType.SYNC: + def _send_tx(self, tx_raw_bytes, tx_type: pt.TxType) -> abci_type.TxResponse: + if tx_type == pt.TxType.SYNC: return self.client.send_tx_sync_mode(tx_raw_bytes) - elif tx_type == TxType.ASYNC: + elif tx_type == pt.TxType.ASYNC: return self.client.send_tx_async_mode(tx_raw_bytes) return self.client.send_tx_block_mode(tx_raw_bytes) diff --git a/nibiru/utils.py b/nibiru/utils.py index 2fba3d1d..e12f5292 100644 --- a/nibiru/utils.py +++ b/nibiru/utils.py @@ -1,8 +1,9 @@ +import collections import json import logging import sys from datetime import datetime -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union from google.protobuf.timestamp_pb2 import Timestamp @@ -303,3 +304,117 @@ def clean_nested_dict(dictionary: Union[List, Dict, str]) -> Dict: dictionary[key] = value return dictionary + + +# ---------------------------------------------------- +# ---------------------------------------------------- + + +def dict_keys_must_match(dict_: dict, keys: List[str]): + """Asserts that two iterables have the same elements, the same number of + times, without regard to order. + Alias for the 'element_counts_are_equal' function. + + dict_keys_must_match(dict_, keys) + + Example: + - [0, 1, 1] and [1, 0, 1] compare equal. + - [0, 0, 1] and [0, 1] compare unequal. + + """ + assert element_counts_are_equal(dict_.keys(), keys) + + +def element_counts_are_equal( + first: Iterable[Any], second: Iterable[Any] +) -> Optional[bool]: + """Asserts that two iterables have the same elements, the same number of + times, without regard to order. + + Args: + first (Iterable[Any]) + second (Iterable[Any]) + + Returns: + Optional[bool]: "passed" status. If this is True, first and second share + the same element counts. If they don't the function will raise an + AssertionError and return 'None'. + """ + first_seq, second_seq = list(first), list(second) + + passed: Union[bool, None] + try: + first = collections.Counter(first_seq) + second = collections.Counter(second_seq) + except TypeError: + # Handle case with unhashable elements + differences = _count_diff_all_purpose(first_seq, second_seq) + else: + if first == second: + passed = True + return passed + differences = _count_diff_hashable(first_seq, second_seq) + + if differences: + standardMsg = "Element counts were not equal:\n" + lines = ["First has %d, Second has %d: %r" % diff for diff in differences] + diffMsg = "\n".join(lines) + msg = "\n".join([standardMsg, diffMsg]) + passed = False + assert passed, msg + + +_Mismatch = collections.namedtuple("Mismatch", "actual expected value") + + +def _count_diff_all_purpose(actual, expected): + "Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ" + # elements need not be hashable + s, t = list(actual), list(expected) + m, n = len(s), len(t) + NULL = object() + result = [] + for i, elem in enumerate(s): + if elem is NULL: + continue + cnt_s = cnt_t = 0 + for j in range(i, m): + if s[j] == elem: + cnt_s += 1 + s[j] = NULL + for j, other_elem in enumerate(t): + if other_elem == elem: + cnt_t += 1 + t[j] = NULL + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + + for i, elem in enumerate(t): + if elem is NULL: + continue + cnt_t = 0 + for j in range(i, n): + if t[j] == elem: + cnt_t += 1 + t[j] = NULL + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result + + +def _count_diff_hashable(actual, expected): + "Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ" + # elements must be hashable + s, t = collections.Counter(actual), collections.Counter(expected) + result = [] + for elem, cnt_s in s.items(): + cnt_t = t.get(elem, 0) + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + for elem, cnt_t in t.items(): + if elem not in s: + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result diff --git a/poetry.lock b/poetry.lock index 74b9c6ca..389d0a55 100644 --- a/poetry.lock +++ b/poetry.lock @@ -641,6 +641,20 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "pytest-order" +version = "1.0.1" +description = "pytest plugin to run your tests in a specific order" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +pytest = [ + {version = ">=5.0", markers = "python_version < \"3.10\""}, + {version = ">=6.2.4", markers = "python_version >= \"3.10\""}, +] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -875,7 +889,7 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "32c785dc385756378e9fcc3e5eafe744e4a35571b6e319f199a690a252796001" +content-hash = "fa15166dd2d341be2c558c0ac27254c4b30721a2b61c80abea2bd84fb91f4674" [metadata.files] aiocron = [ @@ -1665,6 +1679,10 @@ pytest = [ {file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"}, {file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"}, ] +pytest-order = [ + {file = "pytest-order-1.0.1.tar.gz", hash = "sha256:5dd6b929fbd7eaa6d0ee07586f65c623babb0afe72b4843c5f15055d6b3b1b1f"}, + {file = "pytest_order-1.0.1-py3-none-any.whl", hash = "sha256:bbe6e63a8e23741ab3e810d458d1ea7317e797b70f9550512d77d6e9e8fd1bbb"}, +] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, diff --git a/pyproject.toml b/pyproject.toml index a0fb776c..593794bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nibiru" -version = "0.16.0-beta.1" +version = "0.16.0-beta.1.dev1" description = "Python SDK for interacting with Nibiru." authors = ["Nibiru Chain "] license = "MIT" @@ -46,6 +46,9 @@ pytest = "^7.1.3" black = "^22.10.0" +[tool.poetry.group.dev.dependencies] +pytest-order = "^1.0.1" + [tool.black] line-length = 88 skip-string-normalization = true diff --git a/tests/websocket_test.py b/tests-heavy/websocket_test.py similarity index 82% rename from tests/websocket_test.py rename to tests-heavy/websocket_test.py index e3399e43..70409552 100644 --- a/tests/websocket_test.py +++ b/tests-heavy/websocket_test.py @@ -5,13 +5,13 @@ import nibiru import nibiru.msg -from nibiru import Network, Sdk, Transaction, common +from nibiru import Network, Sdk, Transaction, pytypes from nibiru.event_specs import EventCaptured from nibiru.websocket import EventType, NibiruWebsocket from tests import LOGGER -def test_websocket_listen(val_node: nibiru.Sdk, network: Network): +def test_websocket_listen(sdk_val: nibiru.Sdk, network: Network): """ Open a position and ensure output is correct """ @@ -45,36 +45,36 @@ def test_websocket_listen(val_node: nibiru.Sdk, network: Network): # Open a position from the validator node LOGGER.info("Opening position") - val_node.tx.execute_msgs( + sdk_val.tx.execute_msgs( [ nibiru.msg.MsgOpenPosition( - sender=val_node.address, + sender=sdk_val.address, token_pair=pair, - side=common.Side.BUY, + side=pytypes.Side.BUY, quote_asset_amount=10, leverage=10, base_asset_amount_limit=0, ), nibiru.msg.MsgSend( - from_address=val_node.address, + from_address=sdk_val.address, to_address="nibi1a9s5adwysufv4n5ed2ahs4kaqkaf2x3upm2r9p", # random address coins=nibiru.Coin(amount=10, denom="unibi"), ), ] ) - val_node.tx.execute_msgs( + sdk_val.tx.execute_msgs( nibiru.msg.MsgPostPrice( - oracle=val_node.address, + oracle=sdk_val.address, token0="unibi", token1="unusd", price=10, expiry=datetime.utcnow() + timedelta(hours=1), ), ) - val_node.tx.execute_msgs( + sdk_val.tx.execute_msgs( nibiru.msg.MsgPostPrice( - oracle=val_node.address, + oracle=sdk_val.address, token0="unibi", token1="unusd", price=11, @@ -83,9 +83,9 @@ def test_websocket_listen(val_node: nibiru.Sdk, network: Network): ) LOGGER.info("Closing position") - val_node.tx.execute_msgs( + sdk_val.tx.execute_msgs( nibiru.msg.MsgClosePosition( - sender=val_node.address, + sender=sdk_val.address, token_pair=pair, ) ) @@ -96,7 +96,6 @@ def test_websocket_listen(val_node: nibiru.Sdk, network: Network): nibiru_websocket.queue.put(None) events: List[EventCaptured] = [] - event = 1 while True: event = nibiru_websocket.queue.get() if event is None: @@ -117,7 +116,7 @@ def test_websocket_listen(val_node: nibiru.Sdk, network: Network): assert not missing_events, f"Missing events: {missing_events}" -def test_websocket_tx_fail_queue(val_node: Sdk, network: Network): +def test_websocket_tx_fail_queue(sdk_val: Sdk, network: Network): """ Try executing failing TXs and get errors from tx_fail_queue """ @@ -132,14 +131,14 @@ def test_websocket_tx_fail_queue(val_node: Sdk, network: Network): time.sleep(1) # Send failing closing transaction without simulation - val_node.tx.client.sync_timeout_height() - address = val_node.tx.get_address_info() + sdk_val.tx.client.sync_timeout_height() + address = sdk_val.tx.get_address_info() tx = ( Transaction() .with_messages( [ nibiru.msg.MsgClosePosition( - sender=val_node.address, + sender=sdk_val.address, token_pair="abc:def", ).to_pb() ] @@ -147,9 +146,9 @@ def test_websocket_tx_fail_queue(val_node: Sdk, network: Network): .with_sequence(address.get_sequence()) .with_account_num(address.get_number()) .with_chain_id(network.chain_id) - .with_signer(val_node.tx.priv_key) + .with_signer(sdk_val.tx.priv_key) ) - val_node.tx.execute_tx(tx, 300000) + sdk_val.tx.execute_tx(tx, 300000) time.sleep(3) diff --git a/tests/__init__.py b/tests/__init__.py index 34216105..5b80f686 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,16 +1,29 @@ """Tests package for the Nibiru Python SDK""" -import collections import logging import pprint -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List, Union import shutup -from nibiru.utils import init_logger +from nibiru import utils shutup.please() -LOGGER: logging.Logger = init_logger("test-logger") +LOGGER: logging.Logger = utils.init_logger("test-logger") + + +def raises(errs: Union[str, Iterable[str]], err: BaseException): + """Makes sure one of the errors in 'errs' in contained in 'err'. If none of + the given exceptions were raised, this function raises the original exception. + """ + if isinstance(errs, str): + errs = [errs] + else: + errs = list(errs) + errs: List[str] + + err_string = str(err) + assert any([e in err_string for e in errs]), err_string def format_response(resp: Union[dict, list, str]) -> str: @@ -46,7 +59,7 @@ def dict_keys_must_match(dict_: dict, keys: List[str]): - [0, 0, 1] and [0, 1] compare unequal. """ - assert element_counts_are_equal(dict_.keys(), keys) + assert utils.element_counts_are_equal(dict_.keys(), keys) def transaction_must_succeed(tx_output: dict): @@ -59,112 +72,9 @@ def transaction_must_succeed(tx_output: dict): """ assert isinstance(tx_output, dict) - dict_keys_must_match( - tx_output, - [ - "height", - "txhash", - "data", - "rawLog", - "logs", - "gasWanted", - "gasUsed", - "events", - ], - ) + expected_keys = ["height", "txhash", "data", "rawLog", "logs", "gasWanted"] + [ + "gasUsed", + "events", + ] + dict_keys_must_match(tx_output, expected_keys) assert isinstance(tx_output["rawLog"], list) - - -def element_counts_are_equal( - first: Iterable[Any], second: Iterable[Any] -) -> Optional[bool]: - """Asserts that two iterables have the same elements, the same number of - times, without regard to order. - - Args: - first (Iterable[Any]) - second (Iterable[Any]) - - Returns: - Optional[bool]: "passed" status. If this is True, first and second share - the same element counts. If they don't the function will raise an - AssertionError and return 'None'. - """ - first_seq, second_seq = list(first), list(second) - - passed: Union[bool, None] - try: - first = collections.Counter(first_seq) - second = collections.Counter(second_seq) - except TypeError: - # Handle case with unhashable elements - differences = _count_diff_all_purpose(first_seq, second_seq) - else: - if first == second: - passed = True - return passed - differences = _count_diff_hashable(first_seq, second_seq) - - if differences: - standardMsg = "Element counts were not equal:\n" - lines = ["First has %d, Second has %d: %r" % diff for diff in differences] - diffMsg = "\n".join(lines) - msg = "\n".join([standardMsg, diffMsg]) - passed = False - assert passed, msg - - -_Mismatch = collections.namedtuple("Mismatch", "actual expected value") - - -def _count_diff_all_purpose(actual, expected): - "Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ" - # elements need not be hashable - s, t = list(actual), list(expected) - m, n = len(s), len(t) - NULL = object() - result = [] - for i, elem in enumerate(s): - if elem is NULL: - continue - cnt_s = cnt_t = 0 - for j in range(i, m): - if s[j] == elem: - cnt_s += 1 - s[j] = NULL - for j, other_elem in enumerate(t): - if other_elem == elem: - cnt_t += 1 - t[j] = NULL - if cnt_s != cnt_t: - diff = _Mismatch(cnt_s, cnt_t, elem) - result.append(diff) - - for i, elem in enumerate(t): - if elem is NULL: - continue - cnt_t = 0 - for j in range(i, n): - if t[j] == elem: - cnt_t += 1 - t[j] = NULL - diff = _Mismatch(0, cnt_t, elem) - result.append(diff) - return result - - -def _count_diff_hashable(actual, expected): - "Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ" - # elements must be hashable - s, t = collections.Counter(actual), collections.Counter(expected) - result = [] - for elem, cnt_s in s.items(): - cnt_t = t.get(elem, 0) - if cnt_s != cnt_t: - diff = _Mismatch(cnt_s, cnt_t, elem) - result.append(diff) - for elem, cnt_t in t.items(): - if elem not in s: - diff = _Mismatch(0, cnt_t, elem) - result.append(diff) - return result diff --git a/tests/auth_test.py b/tests/auth_test.py index 3b0acce0..a6851a3b 100644 --- a/tests/auth_test.py +++ b/tests/auth_test.py @@ -2,20 +2,20 @@ import tests -def test_query_auth_account(val_node: nibiru.Sdk): +def test_query_auth_account(sdk_val: nibiru.Sdk): tests.LOGGER.debug( - "val_node", + "sdk_val", ) - query_resp: dict = val_node.query.auth.account(val_node.address)["account"] + query_resp: dict = sdk_val.query.auth.account(sdk_val.address)["account"] tests.dict_keys_must_match( query_resp, ['@type', 'address', 'pubKey', 'sequence', 'accountNumber'] ) -def test_query_auth_accounts(val_node: nibiru.Sdk): - query_resp: dict = val_node.query.auth.accounts() +def test_query_auth_accounts(sdk_val: nibiru.Sdk): + query_resp: dict = sdk_val.query.auth.accounts() for account in query_resp["accounts"]: diff --git a/tests/bank_test.py b/tests/bank_test.py index cf224051..eaadb409 100644 --- a/tests/bank_test.py +++ b/tests/bank_test.py @@ -8,20 +8,20 @@ PRECISION = 6 -def test_send_multiple_msgs(val_node: nibiru.Sdk, agent: nibiru.Sdk): +def test_send_multiple_msgs(sdk_val: nibiru.Sdk, sdk_agent: nibiru.Sdk): """Tests the transfer of funds for a transaction with a multiple 'MsgSend' messages.""" - tx_output = val_node.tx.execute_msgs( + tx_output = sdk_val.tx.execute_msgs( [ nibiru.msg.MsgSend( - val_node.address, - agent.address, - [Coin(10000, "unibi"), Coin(100, "unusd")], + sdk_val.address, + sdk_agent.address, + [Coin(7, "unibi"), Coin(70, "unusd")], ), nibiru.msg.MsgSend( - val_node.address, - agent.address, - [Coin(10000, "unibi"), Coin(100, "unusd")], + sdk_val.address, + sdk_agent.address, + [Coin(15, "unibi"), Coin(23, "unusd")], ), ] ) @@ -32,15 +32,15 @@ def test_send_multiple_msgs(val_node: nibiru.Sdk, agent: nibiru.Sdk): tests.transaction_must_succeed(tx_output) -def test_send_single_msg(val_node: nibiru.Sdk, agent: nibiru.Sdk): +def test_send_single_msg(sdk_val: nibiru.Sdk, sdk_agent: nibiru.Sdk): """Tests the transfer of funds for a transaction with a single 'MsgSend' message.""" - tx_output = val_node.tx.execute_msgs( + tx_output = sdk_val.tx.execute_msgs( [ nibiru.msg.MsgSend( - val_node.address, - agent.address, - [Coin(10000, "unibi"), Coin(100, "unusd")], + sdk_val.address, + sdk_agent.address, + [Coin(10, "unibi"), Coin(10, "unusd")], ), ] ) diff --git a/tests/chain_info_test.py b/tests/chain_info_test.py index 2096a049..52104a79 100644 --- a/tests/chain_info_test.py +++ b/tests/chain_info_test.py @@ -19,19 +19,19 @@ def test_genesis_block_ping(network: Network): assert all([key in query_resp.keys() for key in ["jsonrpc", "id", "result"]]) -def test_get_chain_id(val_node: Sdk): - assert val_node.network.chain_id == val_node.query.get_chain_id() +def test_get_chain_id(sdk_val: Sdk): + assert sdk_val.network.chain_id == sdk_val.query.get_chain_id() -def test_wait_next_block(val_node: Sdk): - current_block_height = val_node.query.get_latest_block().block.header.height - val_node.query.wait_for_next_block() - new_block_height = val_node.query.get_latest_block().block.header.height +def test_wait_next_block(sdk_val: Sdk): + current_block_height = sdk_val.query.get_latest_block().block.header.height + sdk_val.query.wait_for_next_block() + new_block_height = sdk_val.query.get_latest_block().block.header.height assert new_block_height > current_block_height -def test_version_works(val_node: Sdk): +def test_version_works(sdk_val: Sdk): tests = [ {"should_fail": False, "versions": ["0.3.2", "0.3.2"]}, {"should_fail": True, "versions": ["0.3.2", "0.3.4"]}, @@ -46,13 +46,13 @@ def test_version_works(val_node: Sdk): for test in tests: if test["should_fail"]: with pytest.raises(AssertionError, match="Version error"): - val_node.query.assert_compatible_versions(*test["versions"]) + sdk_val.query.assert_compatible_versions(*test["versions"]) else: - val_node.query.assert_compatible_versions(*test["versions"]) + sdk_val.query.assert_compatible_versions(*test["versions"]) -def test_query_perp_params(val_node: Sdk): - params: Dict[str, Union[float, str]] = val_node.query.perp.params() +def test_query_perp_params(sdk_val: Sdk): + params: Dict[str, Union[float, str]] = sdk_val.query.perp.params() perp_param_names: List[str] = [ "ecosystemFundFeeRatio", "feePoolFeeRatio", @@ -63,14 +63,14 @@ def test_query_perp_params(val_node: Sdk): assert all([(param_name in params) for param_name in perp_param_names]) -def test_block_getters(agent: Sdk): +def test_block_getters(sdk_agent: Sdk): """Tests queries from the Tendemint gRPC channel - GetBlockByHeight - GetLatestBlock """ - block_by_height_resp = agent.query.get_block_by_height(2) - latest_block_resp = agent.query.get_latest_block() + block_by_height_resp = sdk_agent.query.get_block_by_height(2) + latest_block_resp = sdk_agent.query.get_latest_block() block_id_fields: List[str] = ["hash", "part_set_header"] block_fields: List[str] = ["data", "evidence", "header", "last_commit"] for block_resp in [block_by_height_resp, latest_block_resp]: @@ -82,12 +82,12 @@ def test_block_getters(agent: Sdk): ), "missing attributes on the 'block' field" -def test_blocks_getters(agent: Sdk): +def test_blocks_getters(sdk_agent: Sdk): """Tests queries from the Tendemint gRPC channel - GetBlocksByHeight """ - block_by_height_resp = agent.query.get_blocks_by_height(2, 5) + block_by_height_resp = sdk_agent.query.get_blocks_by_height(2, 5) block_id_fields: List[str] = ["hash", "part_set_header"] block_fields: List[str] = ["data", "evidence", "header", "last_commit"] for block_resp in block_by_height_resp: @@ -99,9 +99,9 @@ def test_blocks_getters(agent: Sdk): ), "missing attributes on the 'block' field" -def test_query(val_node: Sdk): +def test_query(sdk_val: Sdk): """ Open a position and ensure output is correct """ - assert isinstance(val_node.query.get_latest_block_height(), int) - assert isinstance(val_node.query.get_version(), str) + assert isinstance(sdk_val.query.get_latest_block_height(), int) + assert isinstance(sdk_val.query.get_version(), str) diff --git a/tests/conftest.py b/tests/conftest.py index d7eb198f..b7a7ae64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,42 +7,64 @@ - docs reference: https://docs.pytest.org/en/6.2.x/fixture.html Fixtures available: -- network -- val_node -- agent_node +- sdk_val +- sdk_agent +- sdk_oracle """ import os -from typing import List +from typing import Any, Dict, List, Optional +import dotenv import pytest -from dotenv import load_dotenv from nibiru import Network, Sdk -from nibiru.common import TxConfig, TxType +from nibiru.pytypes import TxConfig, TxType -EXPECTED_ENV_VARS: List[str] = [ - "LCD_ENDPOINT", - "GRPC_ENDPOINT", - "TENDERMINT_RPC_ENDPOINT", - "WEBSOCKET_ENDPOINT", - "CHAIN_ID", - "VALIDATOR_MNEMONIC", - "ORACLE_MNEMONIC", - "NETWORK_INSECURE", -] +PYTEST_GLOBALS_REQUIRED: Dict[str, str] = dict( + VALIDATOR_MNEMONIC="", + ORACLE_MNEMONIC="", +) +PYTEST_GLOBALS_OPTIONAL: Dict[str, Any] = dict( + use_localnet=False, + LCD_ENDPOINT="", + GRPC_ENDPOINT="", + TENDERMINT_RPC_ENDPOINT="", + WEBSOCKET_ENDPOINT="", + CHAIN_ID="", +) +PYTEST_GLOBALS: Dict[str, Any] = { + **PYTEST_GLOBALS_REQUIRED, # combines dictionaries + **PYTEST_GLOBALS_OPTIONAL, +} def pytest_configure(config): - load_dotenv() - - for env_var in EXPECTED_ENV_VARS: - val = os.getenv(env_var) - if not val: - raise ValueError(f"Environment variable {env_var} is missing!") - setattr(pytest, env_var, val) # pytest. = val - - # NETWORK_INSECURE must be a boolean - pytest.NETWORK_INSECURE = os.getenv("NETWORK_INSECURE") != "false" + dotenv.load_dotenv() + + EXPECTED_ENV_VARS: List[str] = list(PYTEST_GLOBALS.keys()) + + def set_pytest_global(name: str, value: Any): + """Adds environment variables to the 'pytest' object and the 'PYTEST_GLOBALS' + dictionary so that a central point of truth on what variables are set + can be accessed from within tests. + """ + setattr(pytest, name, value) # pytest. = val + PYTEST_GLOBALS[name] = value + + use_localnet: Optional[str] = os.getenv("USE_LOCALNET") + if use_localnet is not None: + if use_localnet.lower() == "true": + set_pytest_global("use_localnet", True) + if not use_localnet: + EXPECTED_ENV_VARS = [key for key in PYTEST_GLOBALS_REQUIRED] + set_pytest_global("use_localnet", False) + + # Set the expected environment variables. Raise a value error if one is missing + for env_var_name in EXPECTED_ENV_VARS: + env_var_value = os.getenv(env_var_name) + if not env_var_value: + raise ValueError(f"Environment variable {env_var_name} is missing!") + set_pytest_global(env_var_name, env_var_value) @pytest.fixture @@ -58,31 +80,41 @@ def network() -> Network: env="unit_test", ) """ + if PYTEST_GLOBALS["use_localnet"]: + return Network.customnet() return Network.devnet(2) TX_CONFIG: TxConfig = TxConfig( - tx_type=TxType.BLOCK, gas_multiplier=3, gas_price=1, gas_wanted=250000 + tx_type=TxType.BLOCK, + gas_multiplier=1.25, + gas_price=0.25, ) @pytest.fixture -def val_node(network: Network) -> Sdk: +def sdk_val(network: Network) -> Sdk: tx_config = TX_CONFIG - network_insecure: bool = not ("https" in network.tendermint_rpc_endpoint) - return ( Sdk.authorize(pytest.VALIDATOR_MNEMONIC) .with_config(tx_config) - .with_network(network, network_insecure) + .with_network(network) ) @pytest.fixture -def agent(network: Network) -> Sdk: +def sdk_agent(network: Network) -> Sdk: tx_config = TX_CONFIG - network_insecure: bool = not ("https" in network.tendermint_rpc_endpoint) - agent = ( - Sdk.authorize().with_config(tx_config).with_network(network, network_insecure) - ) + agent = Sdk.authorize().with_config(tx_config).with_network(network) return agent + + +# address: nibi10hj3gq54uxd9l5d6a7sn4dcvhd0l3wdgt2zvyp +@pytest.fixture +def sdk_oracle(network: Network) -> Sdk: + tx_config = TX_CONFIG + return ( + Sdk.authorize(pytest.ORACLE_MNEMONIC) + .with_config(tx_config) + .with_network(network) + ) diff --git a/tests/dex_test.py b/tests/dex_test.py index d4bd8003..64d6cf92 100644 --- a/tests/dex_test.py +++ b/tests/dex_test.py @@ -1,23 +1,36 @@ # perp_test.py +from typing import Dict, List + +import pytest + import nibiru import nibiru.msg -from nibiru import Coin, PoolAsset -from nibiru.common import PoolType +import tests +from nibiru import Coin, PoolAsset, utils from nibiru.exceptions import SimulationError +from nibiru.pytypes import PoolType from tests import transaction_must_succeed PRECISION = 6 -def test_dex(val_node: nibiru.Sdk, agent: nibiru.Sdk): +class DexErrors: + same_denom = "a pool with the same denoms already exists" + insufficient_funds = "smaller than 1000000000unibi: insufficient funds" + swap_low_unusd_in_pool = "tokenIn (unusd) must be higher to perform a swap" + no_pool_shares = "0nibiru/pool/" + + +def test_dex_create_pool(sdk_val: nibiru.Sdk): """ Test the workflow for pools """ + try: - tx_output = val_node.tx.execute_msgs( + tx_output = sdk_val.tx.execute_msgs( nibiru.msg.MsgCreatePool( - creator=val_node.address, + creator=sdk_val.address, swap_fee=0.01, exit_fee=0.02, assets=[ @@ -31,17 +44,21 @@ def test_dex(val_node: nibiru.Sdk, agent: nibiru.Sdk): transaction_must_succeed(tx_output) except SimulationError as simulation_error: - assert "a pool with the same denoms already exists" in str(simulation_error) + tests.raises( + [DexErrors.same_denom, DexErrors.insufficient_funds], simulation_error + ) + """ + # # TODO fix: need usdc on-chain to do this try: - tx_output = val_node.tx.execute_msgs( + tx_output = sdk_val.tx.execute_msgs( nibiru.msg.MsgCreatePool( - creator=val_node.address, + creator=sdk_val.address, swap_fee=0.01, exit_fee=0.02, assets=[ - PoolAsset(token=Coin(100, "uusdc"), weight=50), PoolAsset(token=Coin(1000, "unusd"), weight=50), + PoolAsset(token=Coin(100, "uusdc"), weight=50), ], pool_type=PoolType.STABLESWAP, a=10, @@ -49,12 +66,21 @@ def test_dex(val_node: nibiru.Sdk, agent: nibiru.Sdk): ) transaction_must_succeed(tx_output) except SimulationError as simulation_error: - assert "a pool with the same denoms already exists" in str(simulation_error) + assert has_reasonable_err(simulation_error), simulation_error + """ - # Assert pool are there. - pools = val_node.query.dex.pools() - pool_ids = {} - for pool_assets in ["unibi:unusd", "uusdc:unusd"]: + +@pytest.fixture +def pools(sdk_val: nibiru.Sdk) -> List[dict]: + return sdk_val.query.dex.pools() + + +@pytest.fixture +def pool_ids(pools: List[dict]) -> Dict[str, int]: + pool_ids: Dict[str, int] = {} + # for pool_assets in ["unibi:unusd", "unusd:uusdc"]: + # # TODO fix: need usdc on-chain to do this + for pool_assets in ["unibi:unusd"]: pool_assets_expected = set(pool_assets.split(":")) any( @@ -70,7 +96,7 @@ def test_dex(val_node: nibiru.Sdk, agent: nibiru.Sdk): ] ) - pool_ids[pool_assets] = int( + pool_id = int( [ pool["id"] for pool in pools @@ -83,57 +109,90 @@ def test_dex(val_node: nibiru.Sdk, agent: nibiru.Sdk): ) ][0] ) + pool_ids[pool_assets] = pool_id + return pool_ids + + +@pytest.mark.order(after="test_dex_create_pool") +def test_dex_query_pools(pools: List[dict]): + if not pools: + return + + pool = pools[0] + keys = ["id", "address", "poolParams", "poolAssets", "totalWeight", "totalShares"] + assert isinstance(pool, dict) + utils.dict_keys_must_match(pool, keys) + + +@pytest.mark.order(after="test_dex_query_pools") +def test_dex_join_pool(sdk_val: nibiru.Sdk, pool_ids: Dict[str, int]): + try: + tx_output = sdk_val.tx.execute_msgs( + [ + nibiru.msg.MsgJoinPool( + sender=sdk_val.address, + pool_id=pool_ids["unibi:unusd"], + tokens=[Coin(1000, "unibi"), Coin(100, "unusd")], + ), + ] + ) + transaction_must_succeed(tx_output) + except BaseException as err: + tests.raises(DexErrors.no_pool_shares, err) + - # Join/swap/exit pool - tx_output = val_node.tx.execute_msgs( - nibiru.msg.MsgSend( - val_node.address, - agent.address, - [Coin(10000, "unibi"), Coin(200, "unusd"), Coin(200, "uusdc")], +@pytest.mark.order(after="test_dex_join_pool") +def test_dex_swap(sdk_val: nibiru.Sdk, pool_ids: Dict[str, int]): + try: + tx_output = sdk_val.tx.execute_msgs( + [ + # # TODO fix: need usdc on-chain to do this + # nibiru.msg.MsgJoinPool( + # sender=sdk_agent.address, + # pool_id=pool_ids["unusd:uusdc"], + # tokens=[Coin(100, "uusdc"), Coin(100, "unusd")], + # ), + # # TODO fix: need usdc on-chain to do this + # nibiru.msg.MsgSwapAssets( + # sender=sdk_agent.address, + # pool_id=pool_ids["unusd:uusdc"], + # token_in=Coin(100, "uusdc"), + # token_out_denom="unusd", + # ), + nibiru.msg.MsgSwapAssets( + sender=sdk_val.address, + pool_id=pool_ids["unibi:unusd"], + token_in=Coin(100, "unusd"), + token_out_denom="unibi", + ), + ] + ) + transaction_must_succeed(tx_output) + except BaseException as err: + tests.raises(DexErrors.swap_low_unusd_in_pool, err) + + +@pytest.mark.order(after="test_dex_swap") +def test_dex_exit_pool(sdk_val: nibiru.Sdk): + balance = sdk_val.query.get_bank_balances(sdk_val.address)["balances"] + + pool_tokens: List[str] = [ + pool_token for pool_token in balance if "nibiru/pool" in pool_token + ] + if pool_tokens: + tx_output = sdk_val.tx.execute_msgs( + [ + nibiru.msg.MsgExitPool( + sender=sdk_val.address, + pool_id=int(pool_token["denom"].split("/")[-1]), + pool_shares=Coin(pool_token["amount"], pool_token["denom"]), + ) + for pool_token in pool_tokens + ] + ) + transaction_must_succeed(tx_output) + else: + tests.LOGGER.info( + "skipped test for 'nibid tx dex exit-pool' because\n" + + f"{sdk_val.address} did not have LP shares" ) - ) - transaction_must_succeed(tx_output) - - pools = val_node.query.dex.pools() - - tx_output = agent.tx.execute_msgs( - [ - nibiru.msg.MsgJoinPool( - sender=agent.address, - pool_id=pool_ids["unibi:unusd"], - tokens=[Coin(1000, "unibi"), Coin(100, "unusd")], - ), - nibiru.msg.MsgJoinPool( - sender=agent.address, - pool_id=pool_ids["uusdc:unusd"], - tokens=[Coin(100, "uusdc"), Coin(100, "unusd")], - ), - nibiru.msg.MsgSwapAssets( - sender=agent.address, - pool_id=pool_ids["uusdc:unusd"], - token_in=Coin(100, "uusdc"), - token_out_denom="unusd", - ), - nibiru.msg.MsgSwapAssets( - sender=agent.address, - pool_id=pool_ids["unibi:unusd"], - token_in=Coin(100, "unibi"), - token_out_denom="unusd", - ), - ] - ) - transaction_must_succeed(tx_output) - - balance = agent.query.get_bank_balances(agent.address)["balances"] - - tx_output = agent.tx.execute_msgs( - [ - nibiru.msg.MsgExitPool( - sender=agent.address, - pool_id=int(pool_token["denom"].split("/")[-1]), - pool_shares=Coin(pool_token["amount"], pool_token["denom"]), - ) - for pool_token in balance - if "nibiru/pool" in pool_token["denom"] - ] - ) diff --git a/tests/epoch_test.py b/tests/epoch_test.py index 327a115b..e5eb2006 100644 --- a/tests/epoch_test.py +++ b/tests/epoch_test.py @@ -2,13 +2,13 @@ import tests -def test_query_current_epoch(val_node: nibiru.Sdk): - query_resp: dict = val_node.query.epoch.current_epoch("15 min") +def test_query_current_epoch(sdk_val: nibiru.Sdk): + query_resp: dict = sdk_val.query.epoch.current_epoch("15 min") assert query_resp["currentEpoch"] > 0 -def test_query_epoch_info(val_node: nibiru.Sdk): - query_resp: dict = val_node.query.epoch.epoch_infos() +def test_query_epoch_info(sdk_val: nibiru.Sdk): + query_resp: dict = sdk_val.query.epoch.epoch_infos() print(query_resp) assert len(query_resp["epochs"]) > 0 diff --git a/tests/event_test.py b/tests/event_test.py new file mode 100644 index 00000000..2318d9fb --- /dev/null +++ b/tests/event_test.py @@ -0,0 +1,92 @@ +from typing import List + +import pytest + +from nibiru import pytypes + + +class TestEvent: + @pytest.fixture + def raw_events(self) -> List[pytypes.RawEvent]: + return [ + { + 'attributes': [ + { + 'key': 'recipient', + 'value': 'nibi1uvu52rxwqj5ndmm59y6atvx33mru9xrz6sqekr', + }, + { + 'key': 'sender', + 'value': 'nibi1zaavvzxez0elundtn32qnk9lkm8kmcsz44g7xl', + }, + {'key': 'amount', 'value': '7unibi,70unusd'}, + ], + 'type': 'transfer', + }, + { + 'attributes': [ + {'key': 'action', 'value': 'post_price'}, + {'key': 'module', 'value': 'pricefeed'}, + { + 'key': 'sender', + 'value': 'nibi10hj3gq54uxd9l5d6a7sn4dcvhd0l3wdgt2zvyp', + }, + ], + 'type': 'message', + }, + { + 'attributes': [ + {'key': 'expiry', 'value': '"2022-12-09T07:58:49.559512Z"'}, + { + 'key': 'oracle', + 'value': '"nibi10hj3gq54uxd9l5d6a7sn4dcvhd0l3wdgt2zvyp"', + }, + {'key': 'pair_id', 'value': '"ueth:unusd"'}, + {'key': 'pair_price', 'value': '"1800.000000000000000000"'}, + ], + 'type': 'nibiru.pricefeed.v1.EventOracleUpdatePrice', + }, + ] + + def test_parse_attributes(self, raw_events: List[pytypes.RawEvent]): + raw_event = raw_events[0] + assert "attributes" in raw_event + raw_attributes: list[dict[str, str]] = raw_event['attributes'] + attrs: dict[str, str] = pytypes.Event.parse_attributes(raw_attributes) + assert attrs["recipient"] == "nibi1uvu52rxwqj5ndmm59y6atvx33mru9xrz6sqekr" + assert attrs["sender"] == "nibi1zaavvzxez0elundtn32qnk9lkm8kmcsz44g7xl" + assert attrs["amount"] == "7unibi,70unusd" + + raw_event = raw_events[1] + assert "attributes" in raw_event + raw_attributes: list[dict[str, str]] = raw_event['attributes'] + attrs: dict[str, str] = pytypes.Event.parse_attributes(raw_attributes) + assert attrs["action"] == "post_price" + assert attrs["module"] == "pricefeed" + oracle = "nibi10hj3gq54uxd9l5d6a7sn4dcvhd0l3wdgt2zvyp" + assert attrs["sender"] == oracle + + raw_event = raw_events[2] + assert "attributes" in raw_event + raw_attributes: list[dict[str, str]] = raw_event['attributes'] + attrs: dict[str, str] = pytypes.Event.parse_attributes(raw_attributes) + assert attrs["expiry"] == '"2022-12-09T07:58:49.559512Z"' + assert attrs["oracle"] == f'"{oracle}"' + assert attrs["pair_id"] == '"ueth:unusd"' + assert attrs["pair_price"] == '"1800.000000000000000000"' + + def test_new_event(self, raw_events: List[pytypes.RawEvent]): + event = pytypes.Event(raw_events[0]) + assert event.type == "transfer" + for attr in ["recipient", "sender", "amount"]: + assert attr in event.attrs + + event = pytypes.Event(raw_events[1]) + assert event.type == "message" + for attr in ["action", "module", "sender"]: + assert attr in event.attrs + + event = pytypes.Event(raw_events[2]) + assert event.type == "nibiru.pricefeed.v1.EventOracleUpdatePrice" + for attr in ["expiry", "oracle", "pair_id", "pair_price"]: + assert attr in event.attrs diff --git a/tests/perp_test.py b/tests/perp_test.py index 84c73c71..c4bebcb7 100644 --- a/tests/perp_test.py +++ b/tests/perp_test.py @@ -1,112 +1,153 @@ # perp_test.py +from typing import List + import pytest import nibiru import nibiru.msg import tests -from nibiru import Coin, common +from nibiru import pytypes as pt from nibiru.exceptions import QueryError -from tests import LOGGER, dict_keys_must_match, transaction_must_succeed +from tests import dict_keys_must_match, transaction_must_succeed PRECISION = 6 +PAIR = "ubtc:unusd" + + +class ERRORS: + position_not_found = "collections: not found: 'nibiru.perp.v1.Position'" + bad_debt = "bad debt" + underwater_position = "underwater position" + + +def test_open_position(sdk_val: nibiru.Sdk): + try: + tests.LOGGER.info("nibid tx perp open-position") + tx_output: pt.RawTxResp = sdk_val.tx.execute_msgs( + nibiru.msg.MsgOpenPosition( + sender=sdk_val.address, + token_pair=PAIR, + side=pt.Side.SELL, + quote_asset_amount=10, + leverage=10, + base_asset_amount_limit=0, + ) + ) + tests.LOGGER.info( + f"nibid tx perp open-position: {tests.format_response(tx_output)}" + ) + transaction_must_succeed(tx_output) + + tx_resp = pt.TxResp.from_raw(pt.RawTxResp(tx_output)) + assert "/nibiru.perp.v1.MsgOpenPosition" in tx_resp.rawLog[0].msgs + events_for_msg: List[str] = [ + "nibiru.perp.v1.PositionChangedEvent", + "nibiru.vpool.v1.SwapQuoteForBaseEvent", + "nibiru.vpool.v1.MarkPriceChangedEvent", + "transfer", + ] + assert all( + [msg_event in tx_resp.rawLog[0].event_types for msg_event in events_for_msg] + ) + except BaseException as err: + tests.raises(ERRORS.bad_debt, err) + + +@pytest.mark.order(after="test_open_position") +def test_perp_query_position(sdk_val: nibiru.Sdk): + try: + # Trader position must be a dict with specific keys + position_res = sdk_val.query.perp.position( + trader=sdk_val.address, token_pair=PAIR + ) + dict_keys_must_match( + position_res, + [ + "block_number", + "margin_ratio_index", + "margin_ratio_mark", + "position", + "position_notional", + "unrealized_pnl", + ], + ) + tests.LOGGER.info( + f"nibid query perp trader-position: \n{tests.format_response(position_res)}" + ) + + assert position_res["margin_ratio_mark"] + position = position_res["position"] + assert position["margin"] + assert position["open_notional"] + assert position["size"] + except BaseException as err: + tests.raises(ERRORS.position_not_found, err) + + +@pytest.mark.order(after="test_perp_query_position") +def test_perp_add_margin(sdk_val: nibiru.Sdk): + try: + # Transaction add_margin must succeed + tx_output = sdk_val.tx.execute_msgs( + nibiru.msg.MsgAddMargin( + sender=sdk_val.address, + token_pair=PAIR, + margin=pt.Coin(10, "unusd"), + ), + ) + tests.LOGGER.info( + f"nibid tx perp add-margin: \n{tests.format_response(tx_output)}" + ) + except BaseException as err: + tests.raises(ERRORS.bad_debt, err) + + # TODO test: verify the margin changes using the events + -def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk): +@pytest.mark.order(after="test_perp_add_margin") +def test_perp_remove_margin(sdk_val: nibiru.Sdk): + try: + tx_output = sdk_val.tx.execute_msgs( + nibiru.msg.MsgRemoveMargin( + sender=sdk_val.address, + token_pair=PAIR, + margin=pt.Coin(5, "unusd"), + ) + ) + tests.LOGGER.info( + f"nibid tx perp remove-margin: \n{tests.format_response(tx_output)}" + ) + transaction_must_succeed(tx_output) + # TODO test: verify the margin changes using the events + except BaseException as err: + tests.raises(ERRORS.bad_debt, err) + + +@pytest.mark.order(after="test_perp_remove_margin") +def test_perp_close_posititon(sdk_val: nibiru.Sdk): """ Open a position and ensure output is correct """ - pair = "ubtc:unusd" - # Funding agent - val_node.tx.execute_msgs( - nibiru.msg.MsgSend( - val_node.address, agent.address, [Coin(10000, "unibi"), Coin(100, "unusd")] + try: + # Transaction close_position must succeed + tx_output = sdk_val.tx.execute_msgs( + nibiru.msg.MsgClosePosition(sender=sdk_val.address, token_pair=PAIR) ) - ) - - # Exception must be raised when requesting not existing position - with pytest.raises(QueryError, match="not found: 'nibiru.perp.v1.Position'"): - agent.query.perp.position(trader=agent.address, token_pair=pair) - - # Transaction open_position must succeed - tx_output: dict = agent.tx.execute_msgs( - nibiru.msg.MsgOpenPosition( - sender=agent.address, - token_pair=pair, - side=common.Side.BUY, - quote_asset_amount=10, - leverage=10, - base_asset_amount_limit=0, + tests.LOGGER.info( + f"nibid tx perp close-position: \n{tests.format_response(tx_output)}" ) - ) - LOGGER.info(f"nibid tx perp open-position: {tests.format_response(tx_output)}") - transaction_must_succeed(tx_output) - - # Trader position must be a dict with specific keys - position_res = agent.query.perp.position(trader=agent.address, token_pair=pair) - dict_keys_must_match( - position_res, - [ - "block_number", - "margin_ratio_index", - "margin_ratio_mark", - "position", - "position_notional", - "unrealized_pnl", - ], - ) - LOGGER.info( - f"nibid query perp trader-position: \n{tests.format_response(position_res)}" - ) - # Margin ratio must be ~10% - assert position_res["margin_ratio_mark"] == pytest.approx(0.1, PRECISION) - - position = position_res["position"] - assert position["margin"] == 10.0 - assert position["open_notional"] == 100.0 - assert position["size"] == pytest.approx(0.005, PRECISION) - - # Transaction add_margin must succeed - tx_output = agent.tx.execute_msgs( - nibiru.msg.MsgAddMargin( - sender=agent.address, - token_pair=pair, - margin=Coin(10, "unusd"), - ) - ) - LOGGER.info(f"nibid tx perp add-margin: \n{tests.format_response(tx_output)}") - transaction_must_succeed(tx_output) - - # Margin must increase. 10 + 10 = 20 - position = agent.query.perp.position(trader=agent.address, token_pair=pair)[ - "position" - ] - assert position["margin"] == 20.0 - - # Transaction remove_margin must succeed - tx_output = agent.tx.execute_msgs( - nibiru.msg.MsgRemoveMargin( - sender=agent.address, - token_pair=pair, - margin=common.Coin(5, "unusd"), - ) - ) - LOGGER.info(f"nibid tx perp remove-margin: \n{tests.format_response(tx_output)}") - transaction_must_succeed(tx_output) - - # Margin must decrease. 20 - 5 = 15 - position = agent.query.perp.position(trader=agent.address, token_pair=pair)[ - "position" - ] - assert position["margin"] == 15.0 - - # Transaction close_position must succeed - tx_output = agent.tx.execute_msgs( - nibiru.msg.MsgClosePosition(sender=agent.address, token_pair=pair) - ) - LOGGER.info(f"nibid tx perp close-position: \n{tests.format_response(tx_output)}") - transaction_must_succeed(tx_output) - - # Exception must be raised when querying closed position - with pytest.raises(QueryError, match="not found: 'nibiru.perp.v1.Position'"): - agent.query.perp.position(trader=agent.address, token_pair=pair) + transaction_must_succeed(tx_output) + + # Querying the position should raise an exception if it closed successfully + with pytest.raises( + (QueryError, BaseException), match=ERRORS.position_not_found + ): + sdk_val.query.perp.position(trader=sdk_val.address, token_pair=PAIR) + except BaseException as err: + expected_errors: List[str] = [ + ERRORS.position_not_found, + ERRORS.underwater_position, + ] + tests.raises(expected_errors, err) diff --git a/tests/pricefeed_test.py b/tests/pricefeed_test.py index 4e9da08b..532894ab 100644 --- a/tests/pricefeed_test.py +++ b/tests/pricefeed_test.py @@ -1,6 +1,6 @@ # pricefeed_test.py from datetime import datetime, timedelta -from typing import List, Optional +from typing import Dict, List, Optional import pytest from nibiru_proto.proto.cosmos.base.abci.v1beta1.abci_pb2 import TxResponse @@ -10,11 +10,11 @@ from nibiru.msg import MsgPostPrice from tests import dict_keys_must_match, transaction_must_succeed -WHITELISTED_ORACLES: List[str] = [ - "nibi1zaavvzxez0elundtn32qnk9lkm8kmcsz44g7xl", - "nibi15cdcxznuwpuk5hw7t678wpyesy78kwy00qcesa", - "nibi1qqx5reauy4glpskmppy88pz25qp2py5yxvpxdt", -] +WHITELISTED_ORACLES: Dict[str, str] = { + "nibi1hk04vteklhmtwe0zpt7023p5zcgu49e5v3atyp": "CoinGecko oracle", + "nibi10hj3gq54uxd9l5d6a7sn4dcvhd0l3wdgt2zvyp": "CoinMarketCap oracle", + "nibi1r8gjajmlp9tkff0759rmujv568pa7q6v7u4m3z": "Binance oracle", +} def post_price_test_tx( @@ -25,18 +25,18 @@ def post_price_test_tx( tests.LOGGER.info(f"sending 'nibid tx post price' from {from_oracle}") msg = MsgPostPrice( oracle=from_oracle, - token0="unibi", + token0="ueth", token1="unusd", - price=10, + price=1800, expiry=datetime.utcnow() + timedelta(hours=1), ) return sdk.tx.execute_msgs(msg) -def test_post_price_unwhitelisted(agent: nibiru.Sdk): +def test_post_price_unwhitelisted(sdk_oracle: nibiru.Sdk): tests.LOGGER.info("'test_post_price_unwhitelisted' - should error") unwhitested_address = "nibi1pzd5e402eld9kcc3h78tmfrm5rpzlzk6hnxkvu" - queryResp = agent.query.pricefeed.oracles("unibi:unusd") + queryResp = sdk_oracle.query.pricefeed.oracles("ueth:unusd") assert unwhitested_address not in queryResp["oracles"] # TODO tests.LOGGER.info(f"oracle address not whitelisted: {unwhitested_address}") @@ -44,80 +44,90 @@ def test_post_price_unwhitelisted(agent: nibiru.Sdk): with pytest.raises( nibiru.exceptions.SimulationError, match="unknown address" ) as err: - tx_output = post_price_test_tx(sdk=agent, from_oracle=unwhitested_address) + tx_output = post_price_test_tx(sdk=sdk_oracle, from_oracle=unwhitested_address) err_msg = str(err) assert transaction_must_succeed(tx_output) is None, err_msg -def test_grpc_error(val_node: nibiru.Sdk): - # Market unibi:unusd must be in the list of pricefeed markets - markets_output = val_node.query.pricefeed.markets() +def test_query_markets(sdk_oracle: nibiru.Sdk): + # Market ueth:unusd must be in the list of pricefeed markets + markets_output = sdk_oracle.query.pricefeed.markets() assert isinstance(markets_output, dict) assert any( - [market["pair_id"] == "unibi:unusd" for market in markets_output["markets"]] + [market["pair_id"] == "ueth:unusd" for market in markets_output["markets"]] ) - # Oracle must be in the list of unibi:unusd market oracles - unibi_unusd_market = next( + # Oracle must be in the list of ueth:unusd market oracles + ueth_unusd_market = next( market for market in markets_output["markets"] - if market["pair_id"] == "unibi:unusd" + if market["pair_id"] == "ueth:unusd" ) - assert val_node.address in unibi_unusd_market["oracles"] - - # Transaction post_price in the past must raise proper error - with pytest.raises(nibiru.exceptions.SimulationError, match="Price is expired"): - _ = val_node.tx.execute_msgs( - msgs=MsgPostPrice( - val_node.address, - token0="unibi", - token1="unusd", - price=10, - expiry=datetime.utcnow() - timedelta(hours=1), # Price expired + assert sdk_oracle.address in ueth_unusd_market["oracles"] + + +@pytest.mark.order(after="test_query_markets") +class TestPostPrice: + def test_post_in_the_past(self, sdk_oracle: nibiru.Sdk): + # Transaction post_price in the past must raise proper error + with pytest.raises(nibiru.exceptions.SimulationError, match="Price is expired"): + _ = sdk_oracle.tx.execute_msgs( + msgs=MsgPostPrice( + sdk_oracle.address, + token0="ueth", + token1="unusd", + price=1800, + expiry=datetime.utcnow() - timedelta(hours=1), # Price expired + ) ) - ) + def test_post_prices(self, sdk_oracle: nibiru.Sdk): -def test_post_prices(val_node: nibiru.Sdk): + # Market ueth:unusd must be in the list of pricefeed markets + markets_output = sdk_oracle.query.pricefeed.markets() + assert isinstance(markets_output, dict) + assert any( + [market["pair_id"] == "ueth:unusd" for market in markets_output["markets"]] + ) - # Market unibi:unusd must be in the list of pricefeed markets - markets_output = val_node.query.pricefeed.markets() - assert isinstance(markets_output, dict) - assert any( - [market["pair_id"] == "unibi:unusd" for market in markets_output["markets"]] - ) + tests.LOGGER.info("Oracle must be in the list of ueth:unusd market oracles") + ueth_unusd_market = next( + market + for market in markets_output["markets"] + if market["pair_id"] == "ueth:unusd" + ) + assert sdk_oracle.address in ueth_unusd_market["oracles"] - tests.LOGGER.info("Oracle must be in the list of unibi:unusd market oracles") - unibi_unusd_market = next( - market - for market in markets_output["markets"] - if market["pair_id"] == "unibi:unusd" - ) - assert val_node.address in unibi_unusd_market["oracles"] + tests.LOGGER.info("Transaction post_price must succeed") + tx_output = post_price_test_tx(sdk=sdk_oracle) + tests.LOGGER.info( + f"nibid tx pricefeed post-price:\n{tests.format_response(tx_output)}" + ) + transaction_must_succeed(tx_output) + + # Repeating post_price transaction. + # Otherwise, getting "All input prices are expired" on query.pricefeed.price() + if sdk_oracle.address not in WHITELISTED_ORACLES.keys(): + tests.LOGGER.info(f"oracle address not whitelisted: {sdk_oracle.address}") + with pytest.raises(Exception) as err: + tx_output = post_price_test_tx(sdk=sdk_oracle) + err_msg = str(err) + assert transaction_must_succeed(tx_output) is None, err_msg + tx_output = post_price_test_tx(sdk=sdk_oracle) + tests.LOGGER.info( + f"nibid tx pricefeed post-price:\n{tests.format_response(tx_output)}" + ) + assert transaction_must_succeed(tx_output) is None + sdk_oracle.query.wait_for_next_block() - tests.LOGGER.info("Transaction post_price must succeed") - tx_output = post_price_test_tx(sdk=val_node) - tests.LOGGER.info( - f"nibid tx pricefeed post-price:\n{tests.format_response(tx_output)}" - ) - transaction_must_succeed(tx_output) - - # Repeating post_price transaction. - # Otherwise, getting "All input prices are expired" on query.pricefeed.price() - if val_node.address not in WHITELISTED_ORACLES: - tests.LOGGER.info(f"oracle address not whitelisted: {val_node.address}") - with pytest.raises(Exception) as err: - tx_output = post_price_test_tx(sdk=val_node) - err_msg = str(err) - assert transaction_must_succeed(tx_output) is None, err_msg - tx_output = post_price_test_tx(sdk=val_node) - tests.LOGGER.info( - f"nibid tx pricefeed post-price:\n{tests.format_response(tx_output)}" - ) - assert transaction_must_succeed(tx_output) is None + +@pytest.mark.order(after="TestPostPrice::test_post_prices") +def test_pricefeed_queries(sdk_oracle: nibiru.Sdk): # Raw prices must exist after post_price transaction - raw_prices = val_node.query.pricefeed.raw_prices("unibi:unusd")["raw_prices"] + raw_prices: List[dict] = sdk_oracle.query.pricefeed.raw_prices("ueth:unusd")[ + "raw_prices" + ] assert len(raw_prices) >= 1 # Raw price must be a dict with specific keys @@ -125,24 +135,24 @@ def test_post_prices(val_node: nibiru.Sdk): dict_keys_must_match(raw_price, ['expiry', 'oracle_address', 'pair_id', 'price']) # Price feed params must be a dict with specific keys - price_feed_params = val_node.query.pricefeed.params()["params"] + price_feed_params = sdk_oracle.query.pricefeed.params()["params"] tests.LOGGER.info( f"nibid query pricefeed params:\n{tests.format_response(price_feed_params)}" ) dict_keys_must_match(price_feed_params, ['pairs', 'twap_lookback_window']) - # Unibi price object must be a dict with specific keys - unibi_price = val_node.query.pricefeed.price("unibi:unusd")["price"] + # ueth price object must be a dict with specific keys + ueth_price = sdk_oracle.query.pricefeed.price("ueth:unusd")["price"] tests.LOGGER.info( - f"nibid query pricefeed price:\n{tests.format_response(unibi_price)}" + f"nibid query pricefeed price:\n{tests.format_response(ueth_price)}" ) - dict_keys_must_match(unibi_price, ["pair_id", "price", "twap"]) + dict_keys_must_match(ueth_price, ["pair_id", "price", "twap"]) - # At least one pair in prices must be unibi:unusd - prices = val_node.query.pricefeed.prices()["prices"] + # At least one pair in prices must be ueth:unusd + prices = sdk_oracle.query.pricefeed.prices()["prices"] tests.LOGGER.info(f"nibid query pricefeed prices:\n{tests.format_response(prices)}") - assert any([price["pair_id"] == "unibi:unusd" for price in prices]) + assert any([price["pair_id"] == "ueth:unusd" for price in prices]) - # Unibi price object must be a dict with specific keys - unibi_price = next(price for price in prices if price["pair_id"] == "unibi:unusd") - dict_keys_must_match(unibi_price, ["pair_id", "price", "twap"]) + # ueth price object must be a dict with specific keys + ueth_price = next(price for price in prices if price["pair_id"] == "ueth:unusd") + dict_keys_must_match(ueth_price, ["pair_id", "price", "twap"]) diff --git a/tests/staking_test.py b/tests/staking_test.py index 92af289b..f11a9fc0 100644 --- a/tests/staking_test.py +++ b/tests/staking_test.py @@ -9,48 +9,48 @@ from tests import dict_keys_must_match, transaction_must_succeed -def get_validator_operator_address(val_node: Sdk): +def get_validator_operator_address(sdk_val: Sdk): """ Return the first validator and delegator """ - validator = val_node.query.staking.validators()["validators"][0] + validator = sdk_val.query.staking.validators()["validators"][0] return validator["operator_address"] -def delegate(val_node: Sdk): - return val_node.tx.execute_msgs( +def delegate(sdk_val: Sdk): + return sdk_val.tx.execute_msgs( [ MsgDelegate( - delegator_address=val_node.address, - validator_address=get_validator_operator_address(val_node), + delegator_address=sdk_val.address, + validator_address=get_validator_operator_address(sdk_val), amount=1, ), ] ) -def undelegate(val_node: Sdk): - return val_node.tx.execute_msgs( +def undelegate(sdk_val: Sdk): + return sdk_val.tx.execute_msgs( [ MsgUndelegate( - delegator_address=val_node.address, - validator_address=get_validator_operator_address(val_node), + delegator_address=sdk_val.address, + validator_address=get_validator_operator_address(sdk_val), amount=1, ), ] ) -def test_query_vpool(val_node: Sdk): - query_resp = val_node.query.staking.pool() +def test_query_vpool(sdk_val: Sdk): + query_resp = sdk_val.query.staking.pool() assert query_resp["pool"]["bonded_tokens"] >= 0 assert query_resp["pool"]["not_bonded_tokens"] >= 0 -def test_query_delegation(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) - query_resp = val_node.query.staking.delegation( - val_node.address, get_validator_operator_address(val_node) +def test_query_delegation(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) + query_resp = sdk_val.query.staking.delegation( + sdk_val.address, get_validator_operator_address(sdk_val) ) dict_keys_must_match( query_resp["delegation_response"], @@ -61,9 +61,9 @@ def test_query_delegation(val_node: Sdk): ) -def test_query_delegations(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) - query_resp = val_node.query.staking.delegations(val_node.address) +def test_query_delegations(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) + query_resp = sdk_val.query.staking.delegations(sdk_val.address) dict_keys_must_match( query_resp["delegation_responses"][0], [ @@ -73,10 +73,10 @@ def test_query_delegations(val_node: Sdk): ) -def test_query_delegations_to(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) - query_resp = val_node.query.staking.delegations_to( - get_validator_operator_address(val_node) +def test_query_delegations_to(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) + query_resp = sdk_val.query.staking.delegations_to( + get_validator_operator_address(sdk_val) ) dict_keys_must_match( query_resp["delegation_responses"][0], @@ -87,9 +87,9 @@ def test_query_delegations_to(val_node: Sdk): ) -def test_historical_info(val_node: Sdk): +def test_historical_info(sdk_val: Sdk): try: - hist_info = val_node.query.staking.historical_info(1) + hist_info = sdk_val.query.staking.historical_info(1) if hist_info["hist"]["valset"]: dict_keys_must_match( hist_info["hist"]["valset"][0], @@ -111,8 +111,8 @@ def test_historical_info(val_node: Sdk): pass -def test_params(val_node: Sdk): - query_resp = val_node.query.staking.params() +def test_params(sdk_val: Sdk): + query_resp = sdk_val.query.staking.params() dict_keys_must_match( query_resp["params"], [ @@ -125,22 +125,22 @@ def test_params(val_node: Sdk): ) -def test_redelegations(val_node: Sdk): - query_resp = val_node.query.staking.redelegations( - val_node.address, get_validator_operator_address(val_node) +def test_redelegations(sdk_val: Sdk): + query_resp = sdk_val.query.staking.redelegations( + sdk_val.address, get_validator_operator_address(sdk_val) ) dict_keys_must_match(query_resp, ["redelegation_responses", "pagination"]) -def test_unbonding_delegation(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) +def test_unbonding_delegation(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) try: - undelegate(val_node) + undelegate(sdk_val) except SimulationError as ex: assert "too many unbonding" in ex.args[0] - query_resp = val_node.query.staking.unbonding_delegation( - val_node.address, get_validator_operator_address(val_node) + query_resp = sdk_val.query.staking.unbonding_delegation( + sdk_val.address, get_validator_operator_address(sdk_val) ) if query_resp: dict_keys_must_match( @@ -149,14 +149,14 @@ def test_unbonding_delegation(val_node: Sdk): assert len(query_resp["unbond"]["entries"]) > 0 -def test_unbonding_delegations(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) +def test_unbonding_delegations(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) try: - undelegate(val_node) + undelegate(sdk_val) except SimulationError as ex: assert "too many unbonding" in ex.args[0] - query_resp = val_node.query.staking.unbonding_delegations(val_node.address) + query_resp = sdk_val.query.staking.unbonding_delegations(sdk_val.address) dict_keys_must_match(query_resp, ["unbonding_responses", "pagination"]) dict_keys_must_match( query_resp["unbonding_responses"][0], @@ -165,15 +165,15 @@ def test_unbonding_delegations(val_node: Sdk): assert len(query_resp["unbonding_responses"][0]["entries"]) > 0 -def test_unbonding_delegations_from(val_node: Sdk): - transaction_must_succeed(delegate(val_node)) +def test_unbonding_delegations_from(sdk_val: Sdk): + transaction_must_succeed(delegate(sdk_val)) try: - undelegate(val_node) + undelegate(sdk_val) except SimulationError as ex: assert "too many unbonding" in ex.args[0] - query_resp = val_node.query.staking.unbonding_delegations_from( - get_validator_operator_address(val_node) + query_resp = sdk_val.query.staking.unbonding_delegations_from( + get_validator_operator_address(sdk_val) ) dict_keys_must_match(query_resp, ["unbonding_responses", "pagination"]) dict_keys_must_match( @@ -183,8 +183,8 @@ def test_unbonding_delegations_from(val_node: Sdk): assert len(query_resp["unbonding_responses"][0]["entries"]) > 0 -def test_validators(val_node: Sdk): - query_resp = val_node.query.staking.validators() +def test_validators(sdk_val: Sdk): + query_resp = sdk_val.query.staking.validators() dict_keys_must_match(query_resp, ["validators", "pagination"]) assert query_resp["pagination"]["total"] > 0 assert len(query_resp["validators"]) > 0 @@ -206,9 +206,9 @@ def test_validators(val_node: Sdk): ) -def test_validator(val_node: Sdk): - validator = val_node.query.staking.validators()["validators"][0] - query_resp = val_node.query.staking.validator(validator["operator_address"]) +def test_validator(sdk_val: Sdk): + validator = sdk_val.query.staking.validators()["validators"][0] + query_resp = sdk_val.query.staking.validator(validator["operator_address"]) dict_keys_must_match( query_resp["validator"], @@ -228,7 +228,7 @@ def test_validator(val_node: Sdk): ) -def test_staking_events(val_node: Sdk, network: Network): +def test_staking_events(sdk_val: Sdk, network: Network): """ Check staking events are properly filtered """ @@ -242,7 +242,7 @@ def test_staking_events(val_node: Sdk, network: Network): nibiru_websocket.start() time.sleep(1) - delegate(val_node) + delegate(sdk_val) time.sleep(5) nibiru_websocket.queue.put(None) diff --git a/tests/utils_test.py b/tests/utils_test.py index 9de0e4c9..2275bb6f 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,12 +1,14 @@ +from typing import List + import pytest from nibiru_proto.proto.cosmos.bank.v1beta1.tx_pb2 import MsgSend from nibiru_proto.proto.perp.v1.tx_pb2 import MsgOpenPosition import nibiru -from nibiru import Coin, common +import tests +from nibiru import Coin, pytypes from nibiru.query_clients.util import get_block_messages, get_msg_pb_by_type_url from nibiru.utils import from_sdk_dec, to_sdk_dec -from tests import dict_keys_must_match @pytest.mark.parametrize( @@ -77,40 +79,22 @@ def test_get_msg_pb_by_type_url(type_url, cls): assert get_msg_pb_by_type_url(type_url) == cls() -def test_get_block_messages(val_node: nibiru.Sdk, agent: nibiru.Sdk): - pair = "ubtc:unusd" - - val_node.tx.execute_msgs( +def test_get_block_messages(sdk_val: nibiru.Sdk, sdk_agent: nibiru.Sdk): + tx_output: pytypes.RawTxResp = sdk_val.tx.execute_msgs( nibiru.msg.MsgSend( - val_node.address, agent.address, [Coin(10000, "unibi"), Coin(100, "unusd")] - ) - ) - tx_output: dict = agent.tx.execute_msgs( - nibiru.msg.MsgOpenPosition( - sender=agent.address, - token_pair=pair, - side=common.Side.BUY, - quote_asset_amount=10, - leverage=10, - base_asset_amount_limit=0, + sdk_val.address, + sdk_agent.address, + [Coin(10000, "unibi"), Coin(100, "unusd")], ) ) height = int(tx_output["height"]) - block_resp = agent.query.get_block_by_height(height) - messages = get_block_messages(block_resp.block) + block_resp = sdk_agent.query.get_block_by_height(height) + messages: List[dict] = get_block_messages(block_resp.block) - msg_open_position = [ - msg for msg in messages if msg["type_url"] == "/nibiru.perp.v1.MsgOpenPosition" - ] - assert len(msg_open_position) > 0 - dict_keys_must_match( - msg_open_position[0]["value"], - [ - "sender", - "token_pair", - "side", - "quote_asset_amount", - "leverage", - "base_asset_amount_limit", - ], + msg = messages[0] + assert isinstance(msg, dict) + assert msg["type_url"] == "/cosmos.bank.v1beta1.MsgSend" + tests.dict_keys_must_match( + msg["value"], + ["from_address", "to_address", "amount"], ) diff --git a/tests/vpool_test.py b/tests/vpool_test.py index 59eb4b80..8dcaf2d1 100644 --- a/tests/vpool_test.py +++ b/tests/vpool_test.py @@ -3,24 +3,24 @@ import nibiru import tests -from nibiru import common +from nibiru import pytypes -def test_query_vpool_reserve_assets(val_node: nibiru.Sdk): +def test_query_vpool_reserve_assets(sdk_val: nibiru.Sdk): expected_pairs: List[str] = ["ubtc:unusd", "ueth:unusd"] for pair in expected_pairs: - query_resp: dict = val_node.query.vpool.reserve_assets(pair) + query_resp: dict = sdk_val.query.vpool.reserve_assets(pair) assert isinstance(query_resp, dict) assert query_resp["base_asset_reserve"] > 0 assert query_resp["quote_asset_reserve"] > 0 -def test_query_vpool_all_pools(agent: nibiru.Sdk): +def test_query_vpool_all_pools(sdk_agent: nibiru.Sdk): """Tests deserialization and expected attributes for the 'nibid query vpool all-pools' command. """ - query_resp: Dict[str, List[dict]] = agent.query.vpool.all_pools() + query_resp: Dict[str, List[dict]] = sdk_agent.query.vpool.all_pools() tests.dict_keys_must_match(query_resp, keys=["pools", "prices"]) all_vpools: List[dict] = query_resp["pools"] @@ -53,9 +53,9 @@ def test_query_vpool_all_pools(agent: nibiru.Sdk): tests.LOGGER.info(f"vpool_prices: {pprint.pformat(vpool_prices, indent=3)}") -def test_query_vpool_base_asset_price(agent: nibiru.Sdk): - query_resp: Dict[str, List[dict]] = agent.query.vpool.base_asset_price( - pair="ueth:unusd", direction=common.Direction.ADD, base_asset_amount="15" +def test_query_vpool_base_asset_price(sdk_agent: nibiru.Sdk): + query_resp: Dict[str, List[dict]] = sdk_agent.query.vpool.base_asset_price( + pair="ueth:unusd", direction=pytypes.Direction.ADD, base_asset_amount="15" ) tests.dict_keys_must_match(query_resp, keys=["price_in_quote_denom"]) assert isinstance(query_resp["price_in_quote_denom"], float)