Skip to content

Commit

Permalink
chore(python): unify session and protocol versions under one IntEnum
Browse files Browse the repository at this point in the history
[no changelog]
  • Loading branch information
M1nd3r committed Nov 29, 2024
1 parent dbb0d44 commit 8250146
Show file tree
Hide file tree
Showing 14 changed files with 49 additions and 51 deletions.
8 changes: 4 additions & 4 deletions python/src/trezorlib/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import click

from .. import exceptions, transport, ui
from ..client import PROTOCOL_V2, TrezorClient
from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
Expand Down Expand Up @@ -150,11 +150,11 @@ def get_session(

# Try resume session from id
if self.session_id is not None:
if client.protocol_version is Session.CODEC_V1:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is Session.THP_V2:
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
Expand Down Expand Up @@ -311,7 +311,7 @@ def trezorctl_command_with_client(
try:
return func(client, *args, **kwargs)
finally:
if client.protocol_version == PROTOCOL_V2:
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:
Expand Down
4 changes: 2 additions & 2 deletions python/src/trezorlib/cli/trezorctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import click

from .. import __version__, log, messages, protobuf
from ..client import TrezorClient
from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
Expand Down Expand Up @@ -308,7 +308,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
try:
client = get_client(transport)
description = format_device_name(client.features)
if client.protocol_version == Session.THP_V2:
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"
Expand Down
16 changes: 10 additions & 6 deletions python/src/trezorlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import typing as t
from enum import IntEnum

from . import mapping, messages, models
from .mapping import ProtobufMapping
Expand Down Expand Up @@ -48,9 +49,11 @@

LOG = logging.getLogger(__name__)

UNKNOWN = -1
PROTOCOL_V1 = 1
PROTOCOL_V2 = 2

class ProtocolVersion(IntEnum):
UNKNOWN = 0x00
PROTOCOL_V1 = 0x01 # Codec
PROTOCOL_V2 = 0x02 # THP


class TrezorClient:
Expand Down Expand Up @@ -80,12 +83,13 @@ def __init__(
else:
self.protocol = protocol
self.protocol.mapping = self.mapping

if isinstance(self.protocol, ProtocolV1):
self._protocol_version = PROTOCOL_V1
self._protocol_version = ProtocolVersion.PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = PROTOCOL_V2
self._protocol_version = ProtocolVersion.PROTOCOL_V2
else:
self._protocol_version = UNKNOWN
self._protocol_version = ProtocolVersion.UNKNOWN

@classmethod
def resume(
Expand Down
23 changes: 4 additions & 19 deletions python/src/trezorlib/debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType
from .tools import expect, parse_path
from .transport.session import Session, SessionV1, SessionV2
from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -1031,18 +1031,12 @@ class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()
if isinstance(session, SessionV1):
self.client.session_version = 1
elif isinstance(session, SessionV2):
self.client.session_version = 2
elif isinstance(session, SessionDebugWrapper):
if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!")
else:
self.client.session_version = -1 # UNKNOWN

@property
def session_version(self) -> int:
return self.client.session_version
def protocol_version(self) -> int:
return self.client.protocol_version

@property
def client(self) -> TrezorClientDebugLink:
Expand Down Expand Up @@ -1284,7 +1278,6 @@ class TrezorClientDebugLink(TrezorClient):
# by the device.

def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self._session_version: int = -1
try:
debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
Expand All @@ -1311,14 +1304,6 @@ def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.debug.version = self.version
self.passphrase: str | None = None

@property
def session_version(self) -> int:
return self._session_version

@session_version.setter
def session_version(self, value: int) -> None:
self._session_version = value

@property
def layout_type(self) -> LayoutType:
return self.debug.layout_type
Expand Down
2 changes: 0 additions & 2 deletions python/src/trezorlib/transport/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@


class Session:
CODEC_V1: t.Final[int] = 1
THP_V2: t.Final[int] = 2
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
Expand Down
3 changes: 2 additions & 1 deletion tests/device_tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from trezorlib import device, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client


def test_features(client: Client):
session = client.get_session()
f0 = session.features
if Session(session).session_version == Session.CODEC_V1:
if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
# session erases session_id from its features
f0.session_id = session.id
f1 = session.call(messages.Initialize(session_id=session.id))
Expand Down
3 changes: 2 additions & 1 deletion tests/device_tests/test_debuglink.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest

from trezorlib import debuglink, device, messages, misc
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path
Expand Down Expand Up @@ -62,7 +63,7 @@ def test_pin(session: Session):

@pytest.mark.models("core")
def test_softlock_instability(session: Session):
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
raise Exception("THIS NEEDS TO BE CHANGED FOR THP")

def load_device():
Expand Down
3 changes: 2 additions & 1 deletion tests/device_tests/test_msg_applysettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path

Expand Down Expand Up @@ -205,7 +206,7 @@ def test_apply_homescreen_toif(session: Session):

@pytest.mark.models(skip=["legacy", "safe3"])
def test_apply_homescreen_jpeg(session: Session):
if session.session_version is Session.THP_V2:
if session.protocol_version is ProtocolVersion.PROTOCOL_V2:
raise Exception(
"FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER"
)
Expand Down
19 changes: 10 additions & 9 deletions tests/device_tests/test_protection_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest

from trezorlib import btc, device, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
Expand Down Expand Up @@ -61,17 +62,17 @@ def _assert_protection(
client.refresh_features()
assert client.features.pin_protection is pin
assert client.features.passphrase_protection is passphrase
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
new_session = session.client.get_session()
session.lock()
session.end()
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
new_session = session.client.get_session()
return Session(new_session)


def test_initialize(session: Session):
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
# Test is skipped for THP
return

Expand Down Expand Up @@ -194,7 +195,7 @@ def test_get_public_key(session: Session):
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]

if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.PublicKey)

Expand All @@ -208,7 +209,7 @@ def test_get_address(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.Address)

Expand Down Expand Up @@ -323,7 +324,7 @@ def test_sign_message(session: Session):

expected_responses = [_pin_request(session)]

if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)

expected_responses.extend(
Expand Down Expand Up @@ -409,7 +410,7 @@ def test_signtx(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.extend(
[
Expand Down Expand Up @@ -463,11 +464,11 @@ def test_unlocked(session: Session):
def test_passphrase_cached(session: Session):
session = _assert_protection(session, pin=False)
with session:
if session.session_version == 1:
if session.protocol_version == 1:
session.set_expected_responses(
[messages.PassphraseRequest, messages.Address]
)
elif session.session_version == 2:
elif session.protocol_version == 2:
session.set_expected_responses([messages.Address])
else:
raise Exception("Unknown session type")
Expand Down
3 changes: 2 additions & 1 deletion tests/device_tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from trezorlib import cardano, messages, models
from trezorlib.btc import get_public_node
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure
Expand All @@ -33,7 +34,7 @@

def test_thp_end_session(client: Client):
session = Session(client.get_session())
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
# TODO: This test should be skipped on non-THP builds
return

Expand Down
6 changes: 5 additions & 1 deletion tests/device_tests/test_session_id_and_passphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

from trezorlib import device, exceptions, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
Expand Down Expand Up @@ -65,7 +66,10 @@ def _get_xpub(
]
else:
expected_responses = [messages.PublicKey]
if passphrase_v1 is not None and session.session_version == Session.CODEC_V1:
if (
passphrase_v1 is not None
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
):
session.passphrase = passphrase_v1

with session:
Expand Down
4 changes: 2 additions & 2 deletions tests/persistence_tests/test_shamir_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import pytest

from trezorlib import device, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import DebugLink, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import RecoveryStatus

from ..click_tests import common, recovery
Expand Down Expand Up @@ -158,7 +158,7 @@ def assert_mnemonic_keyboard(debug: DebugLink) -> None:
layout = debug.read_layout()

# while keyboard is open, hit the device with Initialize/GetFeatures
if device_handler.client.session_version == Session.CODEC_V1:
if device_handler.client.session_version == ProtocolVersion.PROTOCOL_V1:
device_handler.client.get_management_session().call(messages.Initialize())
device_handler.client.refresh_features()

Expand Down
3 changes: 2 additions & 1 deletion tests/upgrade_tests/test_firmware_upgrades.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from shamir_mnemonic import shamir

from trezorlib import btc, debuglink, device, exceptions, fido, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import (
ApplySettings,
Expand Down Expand Up @@ -373,7 +374,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):

# Get a passphrase-less and a passphrased address.
address = btc.get_address(session, "Bitcoin", PATH)
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(new_session=True))
new_session = emu.client.get_session(passphrase="TREZOR")
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
Expand Down
3 changes: 2 additions & 1 deletion tests/upgrade_tests/test_passphrase_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path

Expand Down Expand Up @@ -139,7 +140,7 @@ def test_init_device(emulator: Emulator):
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
session_id = session.id
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(session_id=session_id))
btc.get_address(
session,
Expand Down

0 comments on commit 8250146

Please sign in to comment.