Skip to content

Commit

Permalink
introduce ComparableEnum for tests (#18108)
Browse files Browse the repository at this point in the history
  • Loading branch information
altendky authored Jun 5, 2024
1 parent 6354ea3 commit a37ad22
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
13 changes: 10 additions & 3 deletions chia/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import sysconfig
import tempfile
from contextlib import AsyncExitStack
from enum import IntEnum
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Tuple, Union

import aiohttp
Expand All @@ -29,7 +28,15 @@
from chia._tests.core.data_layer.util import ChiaRoot
from chia._tests.core.node_height import node_height_at_least
from chia._tests.simulation.test_simulation import test_constants_modified
from chia._tests.util.misc import BenchmarkRunner, GcMode, RecordingWebServer, TestId, _AssertRuntime, measure_overhead
from chia._tests.util.misc import (
BenchmarkRunner,
ComparableEnum,
GcMode,
RecordingWebServer,
TestId,
_AssertRuntime,
measure_overhead,
)
from chia._tests.util.setup_nodes import (
OldSimulatorsAndWallets,
SimulatorsAndWallets,
Expand Down Expand Up @@ -187,7 +194,7 @@ def get_keychain():
KeyringWrapper.cleanup_shared_instance()


class ConsensusMode(IntEnum):
class ConsensusMode(ComparableEnum):
PLAIN = 0
SOFT_FORK_4 = 1
HARD_FORK_2_0 = 2
Expand Down
42 changes: 42 additions & 0 deletions chia/_tests/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import sys
from concurrent.futures import Future
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from statistics import mean
from textwrap import dedent
Expand Down Expand Up @@ -637,3 +638,44 @@ class DataTypeProtocol(Protocol):
def unmarshal(cls: Type[T], marshalled: Dict[str, Any]) -> T: ...

def marshal(self) -> Dict[str, Any]: ...


T_ComparableEnum = TypeVar("T_ComparableEnum", bound="ComparableEnum")


class ComparableEnum(Enum):
def __lt__(self: T_ComparableEnum, other: T_ComparableEnum) -> object:
if self.__class__ is not other.__class__:
return NotImplemented

return self.value.__lt__(other.value)

def __le__(self: T_ComparableEnum, other: T_ComparableEnum) -> object:
if self.__class__ is not other.__class__:
return NotImplemented

return self.value.__le__(other.value)

def __eq__(self: T_ComparableEnum, other: object) -> bool:
if self.__class__ is not other.__class__:
return False

return cast(bool, self.value.__eq__(cast(T_ComparableEnum, other).value))

def __ne__(self: T_ComparableEnum, other: object) -> bool:
if self.__class__ is not other.__class__:
return True

return cast(bool, self.value.__ne__(cast(T_ComparableEnum, other).value))

def __gt__(self: T_ComparableEnum, other: T_ComparableEnum) -> object:
if self.__class__ is not other.__class__:
return NotImplemented

return self.value.__gt__(other.value)

def __ge__(self: T_ComparableEnum, other: T_ComparableEnum) -> object:
if self.__class__ is not other.__class__:
return NotImplemented

return self.value.__ge__(other.value)

0 comments on commit a37ad22

Please sign in to comment.