diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 41187ba53f..b16045b0ac 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -557,6 +557,7 @@ if FROZEN: )) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=[ diff --git a/core/SConscript.unix b/core/SConscript.unix index 73152e1cc6..39e4a845bc 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -613,6 +613,7 @@ if FROZEN: )) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/wire/codec/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'storage/*.py', exclude=[ diff --git a/core/embed/rust/src/ui/model_tt/layout.rs b/core/embed/rust/src/ui/model_tt/layout.rs index 2b5d3aa401..0d76521b05 100644 --- a/core/embed/rust/src/ui/model_tt/layout.rs +++ b/core/embed/rust/src/ui/model_tt/layout.rs @@ -1731,7 +1731,7 @@ pub static mp_module_trezorui2: Module = obj_module! { /// """Calls drop on contents of the root component.""" /// /// class UiResult: - /// """Result of an UI operation.""" + /// """Result of a UI operation.""" /// pass /// /// mock:global diff --git a/core/mocks/generated/trezorui2.pyi b/core/mocks/generated/trezorui2.pyi index 1101c7752d..e61d6ad70d 100644 --- a/core/mocks/generated/trezorui2.pyi +++ b/core/mocks/generated/trezorui2.pyi @@ -1144,7 +1144,7 @@ class LayoutObj(Generic[T]): # rust/src/ui/model_tt/layout.rs class UiResult: - """Result of an UI operation.""" + """Result of a UI operation.""" pass CONFIRMED: UiResult CANCELLED: UiResult diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 0d643fcbfb..c405f7017f 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -47,6 +47,10 @@ import storage storage.cache import storage.cache +storage.cache_codec +import storage.cache_codec +storage.cache_common +import storage.cache_common storage.common import storage.common storage.debug @@ -201,12 +205,20 @@ import trezor.utils trezor.wire import trezor.wire -trezor.wire.codec_v1 -import trezor.wire.codec_v1 +trezor.wire.codec +import trezor.wire.codec +trezor.wire.codec.codec_context +import trezor.wire.codec.codec_context +trezor.wire.codec.codec_v1 +import trezor.wire.codec.codec_v1 trezor.wire.context import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.message_handler +import trezor.wire.message_handler +trezor.wire.protocol_common +import trezor.wire.protocol_common trezor.workflow import trezor.workflow apps @@ -309,6 +321,8 @@ import apps.common.backup apps.common.backup_types import apps.common.backup_types +apps.common.cache +import apps.common.cache apps.common.cbor import apps.common.cbor apps.common.coininfo diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 25015459cf..5552fc86ba 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -1,11 +1,13 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType from trezor.messages import Success, UnlockPath from trezor.ui.layouts import confirm_action +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter from . import workflow_handlers @@ -34,7 +36,7 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -203,12 +205,15 @@ def get_features() -> Features: async def handle_Initialize(msg: Initialize) -> Features: - session_id = storage_cache.start_session(msg.session_id) + import storage.cache_codec as cache_codec + + session_id = cache_codec.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO) - have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) + from storage.cache_common import APP_COMMON_DERIVE_CARDANO + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) if ( have_seed and msg.derive_cardano is not None @@ -216,14 +221,12 @@ async def handle_Initialize(msg: Initialize) -> Features: ): # seed is already derived, and host wants to change derive_cardano setting # => create a new session - storage_cache.end_current_session() - session_id = storage_cache.start_session() + cache_codec.end_current_session() + session_id = cache_codec.start_session() have_seed = False if not have_seed: - storage_cache.set_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) - ) + context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) features = get_features() features.session_id = session_id @@ -252,16 +255,17 @@ async def handle_SetBusy(msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline) else: - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() workflow.close_others() return Success() async def handle_EndSession(msg: EndSession) -> Success: - storage_cache.end_current_session() + ctx = context.get_context() + ctx.release() return Success() @@ -276,7 +280,7 @@ async def handle_Ping(msg: Ping) -> Success: async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: from trezor.messages import PreauthorizedRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common import authorization @@ -289,11 +293,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: req = await call_any(PreauthorizedRequest(), *wire_types) assert req.MESSAGE_WIRE_TYPE is not None - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) if handler is None: - return wire.unexpected_message() + return wire.message_handler.unexpected_message() return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument] @@ -301,7 +303,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: from trezor.crypto import hmac from trezor.messages import UnlockedPathRequest - from trezor.wire.context import call_any, get_context + from trezor.wire.context import call_any from apps.common.paths import SLIP25_PURPOSE from apps.common.seed import Slip21Node, get_seed @@ -342,9 +344,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) assert req.MESSAGE_WIRE_TYPE in wire_types - handler = workflow_handlers.find_registered_handler( - get_context().iface, req.MESSAGE_WIRE_TYPE - ) + handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE) assert handler is not None return await handler(req, msg) # type: ignore [Expected 1 positional argument] @@ -364,7 +364,7 @@ def set_homescreen() -> None: set_default = workflow.set_default # local_cache_attribute - if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): + if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen import busyscreen set_default(busyscreen) @@ -393,7 +393,7 @@ def set_homescreen() -> None: def lock_device(interrupt_workflow: bool = True) -> None: if config.has_pin(): config.lock() - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) set_homescreen() if interrupt_workflow: workflow.close_others() @@ -429,7 +429,7 @@ async def unlock_device() -> None: _SCREENSAVER_IS_ON = False set_homescreen() - wire.remove_filter(_pinlock_filter) + remove_filter(_pinlock_filter) def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: @@ -450,7 +450,9 @@ def reload_settings_from_storage() -> None: workflow.idle_timer.set( storage_device.get_autolock_delay_ms(), lock_device_if_unlocked ) - wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features() + wire.message_handler.EXPERIMENTAL_ENABLED = ( + storage_device.get_experimental_features() + ) if ui.display.orientation() != storage_device.get_rotation(): ui.backlight_fade(ui.BacklightLevels.DIM) ui.display.orientation(storage_device.get_rotation()) @@ -482,4 +484,4 @@ def boot() -> None: backup.activate_repeated_backup() if not config.is_unlocked(): # pinlocked handler should always be the last one - wire.filters.append(_pinlock_filter) + filters.append(_pinlock_filter) diff --git a/core/src/apps/bitcoin/sign_tx/payment_request.py b/core/src/apps/bitcoin/sign_tx/payment_request.py index 8f2f7b88a8..779646cc1c 100644 --- a/core/src/apps/bitcoin/sign_tx/payment_request.py +++ b/core/src/apps/bitcoin/sign_tx/payment_request.py @@ -1,7 +1,7 @@ from micropython import const from typing import TYPE_CHECKING -from trezor.wire import DataError +from trezor.wire import DataError, context from .. import writers @@ -26,7 +26,7 @@ class PaymentRequestVerifier: def __init__( self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain ) -> None: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto.hashlib import sha256 from trezor.utils import HashWriter @@ -42,9 +42,9 @@ def __init__( if msg.nonce: nonce = bytes(msg.nonce) - if cache.get(cache.APP_COMMON_NONCE) != nonce: + if context.cache_get(APP_COMMON_NONCE) != nonce: raise DataError("Invalid nonce in payment request.") - cache.delete(cache.APP_COMMON_NONCE) + context.cache_delete(APP_COMMON_NONCE) else: nonce = b"" if msg.memos: diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 06b662c87b..35f6b3f60c 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -1,8 +1,14 @@ from typing import TYPE_CHECKING -from storage import cache, device +import storage.device as device +from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_CARDANO_ICARUS_TREZOR_SECRET, + APP_COMMON_DERIVE_CARDANO, +) from trezor import wire from trezor.crypto import cardano +from trezor.wire import context from apps.common import mnemonic from apps.common.seed import get_seed @@ -112,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool: def derive_and_store_secrets(passphrase: str) -> None: assert device.is_initialized() - assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO) + assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -132,8 +138,8 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) - cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: @@ -148,19 +154,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai seed = await get_seed() return Keychain(cardano.from_seed_ledger(seed)) - if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO): + if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): raise wire.ProcessError("Cardano derivation is not enabled for this session") if derivation_type == CardanoDerivationType.ICARUS: - cache_entry = cache.APP_CARDANO_ICARUS_SECRET + cache_entry = APP_CARDANO_ICARUS_SECRET else: - cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET + cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET # _get_secret - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) if secret is None: await derive_and_store_roots() - secret = cache.get(cache_entry) + secret = context.cache_get(cache_entry) assert secret is not None root = cardano.from_secret(secret) diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 4d6e58e4d6..08c7de393e 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -1,23 +1,20 @@ from typing import Iterable -import storage.cache as storage_cache +from storage.cache_common import ( + APP_COMMON_AUTHORIZATION_DATA, + APP_COMMON_AUTHORIZATION_TYPE, +) from trezor import protobuf from trezor.enums import MessageType +from trezor.wire import context WIRE_TYPES: dict[int, tuple[int, ...]] = { MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), } -APP_COMMON_AUTHORIZATION_DATA = ( - storage_cache.APP_COMMON_AUTHORIZATION_DATA -) # global_import_cache -APP_COMMON_AUTHORIZATION_TYPE = ( - storage_cache.APP_COMMON_AUTHORIZATION_TYPE -) # global_import_cache - def is_set() -> bool: - return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None + return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None def set(auth_message: protobuf.MessageType) -> None: @@ -29,27 +26,27 @@ def set(auth_message: protobuf.MessageType) -> None: # (because only wire-level messages have wire_type, which we use as identifier) ensure(auth_message.MESSAGE_WIRE_TYPE is not None) assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too - storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) - storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer) + context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) + context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer) def get() -> protobuf.MessageType | None: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if not stored_auth_type: return None - buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"") + buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"") return protobuf.load_message_buffer(buffer, stored_auth_type) def is_set_any_session(auth_type: MessageType) -> bool: - return auth_type in storage_cache.get_int_all_sessions( + return auth_type in context.cache_get_int_all_sessions( APP_COMMON_AUTHORIZATION_TYPE ) def get_wire_types() -> Iterable[int]: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if stored_auth_type is None: return () @@ -57,5 +54,5 @@ def get_wire_types() -> Iterable[int]: def clear() -> None: - storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE) - storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA) + context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE) + context.cache_delete(APP_COMMON_AUTHORIZATION_DATA) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index f0ec4af519..fc56f42f9b 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,25 +1,27 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache +from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED from trezor import wire from trezor.enums import MessageType +from trezor.wire import context +from trezor.wire.message_handler import filters, remove_filter if TYPE_CHECKING: from trezor.wire import Handler, Msg def repeated_backup_enabled() -> bool: - return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + return context.cache_get_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) def activate_repeated_backup() -> None: - storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) - wire.filters.append(_repeated_backup_filter) + context.cache_set_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) + filters.append(_repeated_backup_filter) def deactivate_repeated_backup() -> None: - storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) - wire.remove_filter(_repeated_backup_filter) + context.cache_delete(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) + remove_filter(_repeated_backup_filter) _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( diff --git a/core/src/apps/common/cache.py b/core/src/apps/common/cache.py new file mode 100644 index 0000000000..7c493705d6 --- /dev/null +++ b/core/src/apps/common/cache.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING + +from trezor.wire import context + +if TYPE_CHECKING: + from typing import Awaitable, Callable, ParamSpec + + P = ParamSpec("P") + ByteFunc = Callable[P, bytes] + AsyncByteFunc = Callable[P, Awaitable[bytes]] + + +def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: + """ + Caches the result of a function call based on the given key. + + - If the key is already present in the cache, the cached value is returned + directly without invoking the decorated function. + + - If the key is not present in the cache, the decorated function is executed, + and its result is stored in the cache before being returned to the caller. + """ + + def decorator(func: ByteFunc[P]) -> ByteFunc[P]: + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: + value = context.cache_get(key) + if value is None: + value = func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator + + +def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: + """ + Caches the result of an async function call based on the given key. + + - If the key is already present in the cache, the cached value is returned + directly without invoking the decorated asynchronous function. + + - If the key is not present in the cache, the decorated asynchronous function + is executed, and its result is stored in the cache before being returned + to the caller. + """ + + def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: + value = context.cache_get(key) + if value is None: + value = await func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index 95afa1b8fb..988d828733 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -1,9 +1,10 @@ import utime from typing import Any, NoReturn -import storage.cache as storage_cache +from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK from trezor import TR, config, utils, wire from trezor.ui.layouts import show_error_and_raise +from trezor.wire import context async def _request_sd_salt( @@ -77,7 +78,7 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) + context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) _DEF_ARG_PIN_ENTER: str = TR.pin__enter @@ -91,7 +92,7 @@ async def verify_user_pin( ) -> None: # _get_last_unlock_time last_unlock = int.from_bytes( - storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" + context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) if ( diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index dbdff4463e..ddfe841f61 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -1,15 +1,15 @@ -import storage.cache as storage_cache import storage.device as storage_device -from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY +from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from trezor.enums import SafetyCheckLevel +from trezor.wire import context def read_setting() -> SafetyCheckLevel: """ Returns the effective safety check level. """ - temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) if temporary_safety_check_level: return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] else: @@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) - storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) + context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) else: raise ValueError("Unknown SafetyCheckLevel") diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 58846b4f9d..b09004ae69 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,9 +1,12 @@ from typing import TYPE_CHECKING -import storage.cache as storage_cache import storage.device as storage_device +from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE from trezor import utils from trezor.crypto import hmac +from trezor.wire import context + +from apps.common import cache from . import mnemonic from .passphrase import get as get_passphrase @@ -13,6 +16,12 @@ from .paths import Bip32Path, Slip21Path +if not utils.BITCOIN_ONLY: + from storage.cache_common import ( + APP_CARDANO_ICARUS_SECRET, + APP_COMMON_DERIVE_CARDANO, + ) + class Slip21Node: """ @@ -56,10 +65,10 @@ async def derive_and_store_roots() -> None: if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") - need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED) - need_cardano_secret = storage_cache.get_bool( - storage_cache.APP_COMMON_DERIVE_CARDANO - ) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) + need_seed = not context.cache_is_set(APP_COMMON_SEED) + need_cardano_secret = context.cache_get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) if not need_seed and not need_cardano_secret: return @@ -68,17 +77,17 @@ async def derive_and_store_roots() -> None: if need_seed: common_seed = mnemonic.get_seed(passphrase) - storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed) + context.cache_set(APP_COMMON_SEED, common_seed) if need_cardano_secret: from apps.cardano.seed import derive_and_store_secrets derive_and_store_secrets(passphrase) - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) + @cache.stored_async(APP_COMMON_SEED) async def get_seed() -> bytes: await derive_and_store_roots() - common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) + common_seed = context.cache_get(APP_COMMON_SEED) assert common_seed is not None return common_seed @@ -86,13 +95,13 @@ async def get_seed() -> bytes: # === Bitcoin-only variant === # We use the simple version of `get_seed` that never needs to derive anything else. - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) + @cache.stored_async(APP_COMMON_SEED) async def get_seed() -> bytes: passphrase = await get_passphrase() return mnemonic.get_seed(passphrase) -@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) def _get_seed_without_passphrase() -> bytes: if not storage_device.is_initialized(): raise Exception("Device is not initialized") diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 41c65eb85b..3bfd4772e4 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -71,7 +71,7 @@ def wait_until_layout_is_running(timeout: int | None = _DEADLOCK_SLEEP_MS) -> Aw ) async def return_layout_change( - ctx: wire.context.Context, detect_deadlock: bool = False + ctx: wire.protocol_common.Context, detect_deadlock: bool = False ) -> None: # set up the wait storage.layout_watcher = True @@ -356,11 +356,12 @@ async def _no_op(_msg: Any) -> Success: async def handle_session(iface: WireInterface) -> None: from trezor import protobuf, wire - from trezor.wire import codec_v1, context + from trezor.wire.codec import codec_v1 + from trezor.wire.codec.codec_context import CodecContext global DEBUG_CONTEXT - DEBUG_CONTEXT = ctx = context.Context(iface, WIRE_BUFFER_DEBUG) + DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG) if storage.layout_watcher: try: @@ -391,7 +392,7 @@ async def handle_session(iface: WireInterface) -> None: ) if msg.type not in WORKFLOW_HANDLERS: - await ctx.write(wire.unexpected_message()) + await ctx.write(wire.message_handler.unexpected_message()) continue elif req_type is None: @@ -402,7 +403,7 @@ async def handle_session(iface: WireInterface) -> None: await ctx.write(Success()) continue - req_msg = wire.wrap_protobuf_load(msg.data, req_type) + req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type) try: res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) except Exception as exc: diff --git a/core/src/apps/management/get_nonce.py b/core/src/apps/management/get_nonce.py index 2eb9735340..f35852f6d4 100644 --- a/core/src/apps/management/get_nonce.py +++ b/core/src/apps/management/get_nonce.py @@ -5,10 +5,11 @@ async def get_nonce(msg: GetNonce) -> Nonce: - from storage import cache + from storage.cache_common import APP_COMMON_NONCE from trezor.crypto import random from trezor.messages import Nonce + from trezor.wire import context nonce = random.bytes(32) - cache.set(cache.APP_COMMON_NONCE, nonce) + context.cache_set(APP_COMMON_NONCE, nonce) return Nonce(nonce=nonce) diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 68ca529759..9899b3fe6d 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -38,7 +38,7 @@ async def recovery_process() -> Success: recovery_type = storage_recovery.get_type() - wire.AVOID_RESTARTING_FOR = ( + wire.message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, @@ -59,7 +59,7 @@ async def _continue_repeated_backup() -> None: from apps.common import backup from apps.management.backup_device import perform_backup - wire.AVOID_RESTARTING_FOR = ( + wire.message_handler.AVOID_RESTARTING_FOR = ( MessageType.Initialize, MessageType.GetFeatures, MessageType.EndSession, diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 90d2dec642..eb43ad4e7a 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -57,16 +57,17 @@ async def _init_step( msg: MoneroLiveRefreshStartRequest, keychain: Keychain, ) -> MoneroLiveRefreshStartAck: - import storage.cache as storage_cache + from storage.cache_common import APP_MONERO_LIVE_REFRESH from trezor.messages import MoneroLiveRefreshStartAck + from trezor.wire import context from apps.common import paths await paths.validate_path(keychain, msg.address_n) - if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH): + if not context.cache_get_bool(APP_MONERO_LIVE_REFRESH): await layout.require_confirm_live_refresh() - storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True) + context.cache_set_bool(APP_MONERO_LIVE_REFRESH, True) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/core/src/apps/thp/credential_manager.py b/core/src/apps/thp/credential_manager.py index 73c1d0abcd..adf2ba6240 100644 --- a/core/src/apps/thp/credential_manager.py +++ b/core/src/apps/thp/credential_manager.py @@ -7,7 +7,7 @@ ThpCredentialMetadata, ThpPairingCredential, ) -from trezor.wire import wrap_protobuf_load +from trezor.wire.message_handler import wrap_protobuf_load if TYPE_CHECKING: from apps.common.paths import Slip21Path diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py index 6128884ba2..b65c853c93 100644 --- a/core/src/apps/workflow_handlers.py +++ b/core/src/apps/workflow_handlers.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from trezorio import WireInterface - from trezor.wire import Handler, Msg @@ -215,7 +213,7 @@ def _find_message_handler_module(msg_type: int) -> str: raise ValueError -def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None: +def find_registered_handler(msg_type: int) -> Handler | None: if msg_type in workflow_handlers: # Message has a handler available, return it directly. return workflow_handlers[msg_type] diff --git a/core/src/main.py b/core/src/main.py index 87eba676ba..f0c9a54b06 100644 --- a/core/src/main.py +++ b/core/src/main.py @@ -29,7 +29,7 @@ # trezor.pin imports trezor.utils # We need it as an always-active module because trezor.pin.show_pin_timeout is used -# as an UI callback for storage, which can be invoked at any time +# as a UI callback for storage, which can be invoked at any time import trezor.pin # noqa: F401 # === Prepare the USB interfaces first. Do not connect to the host yet. diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 6b4b52ac1e..72d8a1e418 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,39 +1,48 @@ import builtins import gc -from micropython import const -from typing import TYPE_CHECKING -from trezor import utils +from storage import cache_codec +from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache -if TYPE_CHECKING: - from typing import Sequence, TypeVar, overload +# Cache initialization +_SESSIONLESS_CACHE = SessionlessCache() +_PROTOCOL_CACHE = cache_codec +_PROTOCOL_CACHE.initialize() +_SESSIONLESS_CACHE.clear() + +gc.collect() - T = TypeVar("T") +def clear_all() -> None: + """ + Clears all data from both the protocol cache and the sessionless cache. + """ + global autolock_last_touch + autolock_last_touch = None + _SESSIONLESS_CACHE.clear() + _PROTOCOL_CACHE.clear_all() + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + """ + Returns set of int values associated with a given key from all relevant sessions. + + If the key has the `SESSIONLESS_FLAG` set, the values are retrieved + from the sessionless cache. Otherwise, the values are fetched + from the protocol cache. + """ + if key & SESSIONLESS_FLAG: + values = builtins.set() + encoded = _SESSIONLESS_CACHE.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + return _PROTOCOL_CACHE.get_int_all_sessions(key) -_MAX_SESSIONS_COUNT = const(10) -_SESSIONLESS_FLAG = const(128) -_SESSION_ID_LENGTH = const(32) -# Traditional cache keys -APP_COMMON_SEED = const(0) -APP_COMMON_AUTHORIZATION_TYPE = const(1) -APP_COMMON_AUTHORIZATION_DATA = const(2) -APP_COMMON_NONCE = const(3) -if not utils.BITCOIN_ONLY: - APP_COMMON_DERIVE_CARDANO = const(4) - APP_CARDANO_ICARUS_SECRET = const(5) - APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) - APP_MONERO_LIVE_REFRESH = const(7) +def get_sessionless_cache() -> SessionlessCache: + return _SESSIONLESS_CACHE -# Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG) -APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG) -APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | _SESSIONLESS_FLAG) -APP_COMMON_BUSY_DEADLINE_MS = const(3 | _SESSIONLESS_FLAG) -APP_MISC_COSI_NONCE = const(4 | _SESSIONLESS_FLAG) -APP_MISC_COSI_COMMITMENT = const(5 | _SESSIONLESS_FLAG) -APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | _SESSIONLESS_FLAG) # === Homescreen storage === # This does not logically belong to the "cache" functionality, but the cache module is @@ -49,317 +58,3 @@ # Timestamp of last autolock activity. # Here to persist across main loop restart between workflows. autolock_last_touch: int | None = None - - -class InvalidSessionError(Exception): - pass - - -class DataCache: - fields: Sequence[int] # field sizes - - def __init__(self) -> None: - self.data = [bytearray(f + 1) for f in self.fields] - - def set(self, key: int, value: bytes) -> None: - utils.ensure(key < len(self.fields)) - utils.ensure(len(value) <= self.fields[key]) - self.data[key][0] = 1 - self.data[key][1:] = value - - if TYPE_CHECKING: - - @overload - def get(self, key: int) -> bytes | None: ... - - @overload - def get(self, key: int, default: T) -> bytes | T: # noqa: F811 - ... - - def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - utils.ensure(key < len(self.fields)) - if self.data[key][0] != 1: - return default - return bytes(self.data[key][1:]) - - def is_set(self, key: int) -> bool: - utils.ensure(key < len(self.fields)) - return self.data[key][0] == 1 - - def delete(self, key: int) -> None: - utils.ensure(key < len(self.fields)) - self.data[key][:] = b"\x00" - - def clear(self) -> None: - for i in range(len(self.fields)): - self.delete(i) - - -class SessionCache(DataCache): - def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) - if utils.BITCOIN_ONLY: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - ) - else: - self.fields = ( - 64, # APP_COMMON_SEED - 2, # APP_COMMON_AUTHORIZATION_TYPE - 128, # APP_COMMON_AUTHORIZATION_DATA - 32, # APP_COMMON_NONCE - 0, # APP_COMMON_DERIVE_CARDANO - 96, # APP_CARDANO_ICARUS_SECRET - 96, # APP_CARDANO_ICARUS_TREZOR_SECRET - 0, # APP_MONERO_LIVE_REFRESH - ) - self.last_usage = 0 - super().__init__() - - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - - def clear(self) -> None: - super().clear() - self.last_usage = 0 - self.session_id[:] = b"" - - -class SessionlessCache(DataCache): - def __init__(self) -> None: - self.fields = ( - 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE - 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY - 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK - 8, # APP_COMMON_BUSY_DEADLINE_MS - 32, # APP_MISC_COSI_NONCE - 32, # APP_MISC_COSI_COMMITMENT - 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED - ) - super().__init__() - - -# XXX -# Allocation notes: -# Instantiation of a DataCache subclass should make as little garbage as possible, so -# that the preallocated bytearrays are compact in memory. -# That is why the initialization is two-step: first create appropriately sized -# bytearrays, then later call `clear()` on all the existing objects, which resets them -# to zero length. This is producing some trash - `b[:]` allocates a slice. - -_SESSIONS: list[SessionCache] = [] -for _ in range(_MAX_SESSIONS_COUNT): - _SESSIONS.append(SessionCache()) - -_SESSIONLESS_CACHE = SessionlessCache() - -for session in _SESSIONS: - session.clear() -_SESSIONLESS_CACHE.clear() - -gc.collect() - - -_active_session_idx: int | None = None -_session_usage_counter = 0 - - -def start_session(received_session_id: bytes | None = None) -> bytes: - global _active_session_idx - global _session_usage_counter - - if ( - received_session_id is not None - and len(received_session_id) != _SESSION_ID_LENGTH - ): - # Prevent the caller from setting received_session_id=b"" and finding a cleared - # session. More generally, short-circuit the session id search, because we know - # that wrong-length session ids should not be in cache. - # Reduce to "session id not provided" case because that's what we do when - # caller supplies an id that is not found. - received_session_id = None - - _session_usage_counter += 1 - - # attempt to find specified session id - if received_session_id: - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].session_id == received_session_id: - _active_session_idx = i - _SESSIONS[i].last_usage = _session_usage_counter - return received_session_id - - # allocate least recently used session - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i - - _active_session_idx = lru_session_idx - selected_session = _SESSIONS[lru_session_idx] - selected_session.clear() - selected_session.last_usage = _session_usage_counter - return selected_session.export_session_id() - - -def end_current_session() -> None: - global _active_session_idx - - if _active_session_idx is None: - return - - _SESSIONS[_active_session_idx].clear() - _active_session_idx = None - - -def set(key: int, value: bytes) -> None: - if key & _SESSIONLESS_FLAG: - _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value) - return - if _active_session_idx is None: - raise InvalidSessionError - _SESSIONS[_active_session_idx].set(key, value) - - -def _get_length(key: int) -> int: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] - elif _active_session_idx is None: - raise InvalidSessionError - else: - return _SESSIONS[_active_session_idx].fields[key] - - -def set_int(key: int, value: int) -> None: - length = _get_length(key) - - encoded = value.to_bytes(length, "big") - - # Ensure that the value fits within the length. Micropython's int.to_bytes() - # doesn't raise OverflowError. - assert int.from_bytes(encoded, "big") == value - - set(key, encoded) - - -def set_bool(key: int, value: bool) -> None: - assert _get_length(key) == 0 # skipping get_length in production build - if value: - set(key, b"") - else: - delete(key) - - -if TYPE_CHECKING: - - @overload - def get(key: int) -> bytes | None: ... - - @overload - def get(key: int, default: T) -> bytes | T: # noqa: F811 - ... - - -def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].get(key, default) - - -def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 - encoded = get(key) - if encoded is None: - return default - else: - return int.from_bytes(encoded, "big") - - -def get_bool(key: int) -> bool: # noqa: F811 - return get(key) is not None - - -def get_int_all_sessions(key: int) -> builtins.set[int]: - sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS - values = builtins.set() - for session in sessions: - encoded = session.get(key) - if encoded is not None: - values.add(int.from_bytes(encoded, "big")) - return values - - -def is_set(key: int) -> bool: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].is_set(key) - - -def delete(key: int) -> None: - if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) - if _active_session_idx is None: - raise InvalidSessionError - return _SESSIONS[_active_session_idx].delete(key) - - -if TYPE_CHECKING: - from typing import Awaitable, Callable, ParamSpec, TypeVar - - P = ParamSpec("P") - ByteFunc = Callable[P, bytes] - AsyncByteFunc = Callable[P, Awaitable[bytes]] - - -def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: - def decorator(func: ByteFunc[P]) -> ByteFunc[P]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: - value = get(key) - if value is None: - value = func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: - def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes: - value = get(key) - if value is None: - value = await func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def clear_all() -> None: - global _active_session_idx - global autolock_last_touch - - _active_session_idx = None - _SESSIONLESS_CACHE.clear() - for session in _SESSIONS: - session.clear() - - autolock_last_touch = None diff --git a/core/src/storage/cache_codec.py b/core/src/storage/cache_codec.py new file mode 100644 index 0000000000..7be5989032 --- /dev/null +++ b/core/src/storage/cache_codec.py @@ -0,0 +1,154 @@ +import builtins +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import DataCache +from trezor import utils + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + + +_MAX_SESSIONS_COUNT = const(10) +SESSION_ID_LENGTH = const(32) + + +class SessionCache(DataCache): + """ + A cache for storing values that depend on seed derivation + or are specific to a `protocol_v1` session. + """ + + def __init__(self) -> None: + self.session_id = bytearray(SESSION_ID_LENGTH) + if utils.BITCOIN_ONLY: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + ) + else: + self.fields = ( + 64, # APP_COMMON_SEED + 2, # APP_COMMON_AUTHORIZATION_TYPE + 128, # APP_COMMON_AUTHORIZATION_DATA + 32, # APP_COMMON_NONCE + 0, # APP_COMMON_DERIVE_CARDANO + 96, # APP_CARDANO_ICARUS_SECRET + 96, # APP_CARDANO_ICARUS_TREZOR_SECRET + 0, # APP_MONERO_LIVE_REFRESH + ) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + from trezorcrypto import random # avoid pulling in trezor.crypto + + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +_SESSIONS: list[SessionCache] = [] + + +def initialize() -> None: + # Allocation notes: + # Instantiation of any DataCache subclass should make as little garbage + # as possible so that the preallocated bytearrays are compact in memory. + # That is why the initialization is two-step: first, create appropriately + # sized bytearrays, then call `clear()` on all existing objects, which + # resets them to zero length. The `clear()` function uses `arr[:]`, which + # allocates a slice. + global _SESSIONS + for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionCache()) + + for session in _SESSIONS: + session.clear() + + +_active_session_idx: int | None = None +_session_usage_counter = 0 + + +def get_active_session() -> SessionCache | None: + if _active_session_idx is None: + return None + return _SESSIONS[_active_session_idx] + + +def start_session(received_session_id: bytes | None = None) -> bytes: + global _active_session_idx + global _session_usage_counter + + if ( + received_session_id is not None + and len(received_session_id) != SESSION_ID_LENGTH + ): + # Prevent the caller from setting received_session_id=b"" and finding a cleared + # session. More generally, short-circuit the session id search, because we know + # that wrong-length session ids should not be in cache. + # Reduce to "session id not provided" case because that's what we do when + # caller supplies an id that is not found. + received_session_id = None + + _session_usage_counter += 1 + + # attempt to find specified session id + if received_session_id: + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].session_id == received_session_id: + _active_session_idx = i + _SESSIONS[i].last_usage = _session_usage_counter + return received_session_id + + # allocate least recently used session + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + + _active_session_idx = lru_session_idx + selected_session = _SESSIONS[lru_session_idx] + selected_session.clear() + selected_session.last_usage = _session_usage_counter + return selected_session.export_session_id() + + +def end_current_session() -> None: + global _active_session_idx + + if _active_session_idx is None: + return + + _SESSIONS[_active_session_idx].clear() + _active_session_idx = None + + +def get_int_all_sessions(key: int) -> builtins.set[int]: + values = builtins.set() + for session in _SESSIONS: + encoded = session.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + + +def clear_all() -> None: + global _active_session_idx + _active_session_idx = None + for session in _SESSIONS: + session.clear() diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py new file mode 100644 index 0000000000..90cead81db --- /dev/null +++ b/core/src/storage/cache_common.py @@ -0,0 +1,164 @@ +from micropython import const +from typing import TYPE_CHECKING + +from trezor import utils + +# Traditional cache keys +APP_COMMON_SEED = const(0) +APP_COMMON_AUTHORIZATION_TYPE = const(1) +APP_COMMON_AUTHORIZATION_DATA = const(2) +APP_COMMON_NONCE = const(3) +if not utils.BITCOIN_ONLY: + APP_COMMON_DERIVE_CARDANO = const(4) + APP_CARDANO_ICARUS_SECRET = const(5) + APP_CARDANO_ICARUS_TREZOR_SECRET = const(6) + APP_MONERO_LIVE_REFRESH = const(7) + +# Keys that are valid across sessions +SESSIONLESS_FLAG = const(128) +APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG) +APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG) +APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(2 | SESSIONLESS_FLAG) +APP_COMMON_BUSY_DEADLINE_MS = const(3 | SESSIONLESS_FLAG) +APP_MISC_COSI_NONCE = const(4 | SESSIONLESS_FLAG) +APP_MISC_COSI_COMMITMENT = const(5 | SESSIONLESS_FLAG) +APP_RECOVERY_REPEATED_BACKUP_UNLOCKED = const(6 | SESSIONLESS_FLAG) + + +if TYPE_CHECKING: + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + + +class InvalidSessionError(Exception): + pass + + +class DataCache: + """ + A single unit of cache storage, designed to store common-type + values efficiently in bytearrays in a sequential manner. + """ + + fields: Sequence[int] # field sizes + + def __init__(self) -> None: + self.data = [bytearray(f + 1) for f in self.fields] + + if TYPE_CHECKING: + + @overload + def get(self, key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.fields)) + if self.data[key][0] != 1: + return default + return bytes(self.data[key][1:]) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return self.get(key) is not None + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + encoded = self.get(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + + def is_set(self, key: int) -> bool: + utils.ensure(key < len(self.fields)) + return self.data[key][0] == 1 + + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][0] = 1 + self.data[key][1:] = value + + def set_bool(self, key: int, value: bool) -> None: + assert self._get_length(key) == 0 # skipping get_length in production build + if value: + self.set(key, b"") + else: + self.delete(key) + + def set_int(self, key: int, value: int) -> None: + length = self._get_length(key) + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + self.set(key, encoded) + + def delete(self, key: int) -> None: + utils.ensure(key < len(self.fields)) + # `arr[:]` allocates a slice to prevent memory fragmentation. + self.data[key][:] = b"\x00" + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) + + def _get_length(self, key: int) -> int: + utils.ensure(key < len(self.fields)) + return self.fields[key] + + +class SessionlessCache(DataCache): + """ + A cache for values that are independent of both + passphrase seed derivation and the active session. + """ + + def __init__(self) -> None: + self.fields = ( + 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE + 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY + 8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK + 8, # APP_COMMON_BUSY_DEADLINE_MS + 32, # APP_MISC_COSI_NONCE + 32, # APP_MISC_COSI_COMMITMENT + 0, # APP_RECOVERY_REPEATED_BACKUP_UNLOCKED + ) + super().__init__() + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + return super().get(key & ~SESSIONLESS_FLAG, default) + + def get_bool(self, key: int) -> bool: # noqa: F811 + return super().get_bool(key & ~SESSIONLESS_FLAG) + + def get_int( + self, key: int, default: T | None = None + ) -> int | T | None: # noqa: F811 + return super().get_int(key & ~SESSIONLESS_FLAG, default) + + def is_set(self, key: int) -> bool: + return super().is_set(key & ~SESSIONLESS_FLAG) + + def set(self, key: int, value: bytes) -> None: + super().set(key & ~SESSIONLESS_FLAG, value) + + def set_bool(self, key: int, value: bool) -> None: + super().set_bool(key & ~SESSIONLESS_FLAG, value) + + def set_int(self, key: int, value: int) -> None: + super().set_int(key & ~SESSIONLESS_FLAG, value) + + def delete(self, key: int) -> None: + super().delete(key & ~SESSIONLESS_FLAG) + + def clear(self) -> None: + for i in range(len(self.fields)): + self.delete(i) diff --git a/core/src/trezor/ui/__init__.py b/core/src/trezor/ui/__init__.py index a0ad35a338..7292fefac2 100644 --- a/core/src/trezor/ui/__init__.py +++ b/core/src/trezor/ui/__init__.py @@ -8,6 +8,7 @@ from trezor import io, log, loop, utils, wire, workflow from trezor.messages import ButtonAck, ButtonRequest from trezor.wire import context +from trezor.wire.protocol_common import Context from trezorui2 import BacklightLevels, LayoutState if TYPE_CHECKING: @@ -166,7 +167,7 @@ def __init__(self, layout: LayoutObj[T]) -> None: self.button_request_ack_pending: bool = False self.transition_out: AttachType | None = None self.backlight_level = BacklightLevels.NORMAL - self.context: context.Context | None = None + self.context: Context | None = None self.state: LayoutState = LayoutState.INITIAL def is_ready(self) -> bool: diff --git a/core/src/trezor/ui/layouts/homescreen.py b/core/src/trezor/ui/layouts/homescreen.py index 0fe5d2e1cf..cce243097f 100644 --- a/core/src/trezor/ui/layouts/homescreen.py +++ b/core/src/trezor/ui/layouts/homescreen.py @@ -2,6 +2,7 @@ import storage.cache as storage_cache import trezorui2 +from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS from trezor import TR, ui if TYPE_CHECKING: @@ -122,11 +123,13 @@ def __init__(self, delay_ms: int) -> None: ) async def get_result(self) -> Any: + from trezor.wire import context + from apps.base import set_homescreen # Handle timeout. result = await super().get_result() assert result == trezorui2.CANCELLED - storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() return result diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 7021759ba9..0162d4b8d5 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -111,6 +111,7 @@ def presize_module(modname: str, size: int) -> None: if __debug__: + from ubinascii import hexlify def mem_dump(filename: str) -> None: from micropython import mem_info @@ -127,6 +128,10 @@ def mem_dump(filename: str) -> None: else: mem_info(True) + def get_bytes_as_str(a: bytes) -> str: + """Converts the provided bytes to a hexadecimal string (decoded as`utf-8`).""" + return hexlify(a).decode("utf-8") + def ensure(cond: bool, msg: str | None = None) -> None: if not cond: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 9023bbd288..68bfd3d109 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -23,15 +23,13 @@ """ -from micropython import const from typing import TYPE_CHECKING -from storage.cache import InvalidSessionError -from trezor import log, loop, protobuf, utils, workflow -from trezor.enums import FailureType -from trezor.messages import Failure -from trezor.wire import codec_v1, context -from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor import log, loop, protobuf, utils +from trezor.wire import message_handler, protocol_common +from trezor.wire.codec.codec_context import CodecContext +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.message_handler import WIRE_BUFFER, failure # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -40,158 +38,23 @@ if TYPE_CHECKING: from trezorio import WireInterface - from typing import Any, Callable, Container, Coroutine, TypeVar + from typing import Any, Callable, Coroutine, TypeVar Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[[Msg], HandlerTask] - Filter = Callable[[int, Handler], Handler] LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) -# If set to False protobuf messages marked with "experimental_message" option are rejected. -EXPERIMENTAL_ENABLED = False - - def setup(iface: WireInterface) -> None: - """Initialize the wire stack on passed USB interface.""" + """Initialize the wire stack on the provided WireInterface.""" loop.schedule(handle_session(iface)) -def wrap_protobuf_load( - buffer: bytes, - expected_type: type[LoadedMessageType], -) -> LoadedMessageType: - try: - msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) - if __debug__ and utils.EMULATOR: - log.debug( - __name__, "received message contents:\n%s", utils.dump_protobuf(msg) - ) - return msg - except Exception as e: - if __debug__: - log.exception(__name__, e) - if e.args: - raise DataError("Failed to decode message: " + " ".join(e.args)) - else: - raise DataError("Failed to decode message") - - -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - -if __debug__: - PROTOBUF_BUFFER_SIZE_DEBUG = 1024 - WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) - - -async def _handle_single_message(ctx: context.Context, msg: codec_v1.Message) -> bool: - """Handle a message that was loaded from USB by the caller. - - Find the appropriate handler, run it and write its result on the wire. In case - a problem is encountered at any point, write the appropriate error on the wire. - - The return value indicates whether to override the default restarting behavior. If - `False` is returned, the caller is allowed to clear the loop and restart the - MicroPython machine (see `session.py`). This would lose all state and incurs a cost - in terms of repeated startup time. When handling the message didn't cause any - significant fragmentation (e.g., if decoding the message was skipped), or if - the type of message is supposed to be optimized and not disrupt the running state, - this function will return `True`. - """ - if __debug__: - try: - msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME - except Exception: - msg_type = f"{msg.type} - unknown message type" - log.debug( - __name__, - "%d receive: <%s>", - ctx.iface.iface_num(), - msg_type, - ) - - res_msg: protobuf.MessageType | None = None - - # We need to find a handler for this message type. - try: - handler = find_handler(ctx.iface, msg.type) - except Error as exc: - # Handlers are allowed to exception out. In that case, we can skip decoding - # and return the error. - await ctx.write(failure(exc)) - return True - - if msg.type in workflow.ALLOW_WHILE_LOCKED: - workflow.autolock_interrupts_workflow = False - - # Here we make sure we always respond with a Failure response - # in case of any errors. - try: - # Find a protobuf.MessageType subclass that describes this - # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) - - # Try to decode the message according to schema from - # `req_type`. Raises if the message is malformed. - req_msg = wrap_protobuf_load(msg.data, req_type) - - # Create the handler task. - task = handler(req_msg) - - # Run the workflow task. Workflow can do more on-the-wire - # communication inside, but it should eventually return a - # response message, or raise an exception (a rather common - # thing to do). Exceptions are handled in the code below. - res_msg = await workflow.spawn(context.with_context(ctx, task)) - - except context.UnexpectedMessage: - # Workflow was trying to read a message from the wire, and - # something unexpected came in. See Context.read() for - # example, which expects some particular message and raises - # UnexpectedMessage if another one comes in. - # - # We process the unexpected message by aborting the current workflow and - # possibly starting a new one, initiated by that message. (The main usecase - # being, the host does not finish the workflow, we want other callers to - # be able to do their own thing.) - # - # The message is stored in the exception, which we re-raise for the caller - # to process. It is not a standard exception that should be logged and a result - # sent to the wire. - raise - - except BaseException as exc: - # Either: - # - the message had a type that has a registered handler, but does not have - # a protobuf class - # - the message was not valid protobuf - # - workflow raised some kind of an exception while running - # - something canceled the workflow from the outside - if __debug__: - if isinstance(exc, ActionCancelled): - log.debug(__name__, "cancelled: %s", exc.message) - elif isinstance(exc, loop.TaskClosed): - log.debug(__name__, "cancelled: loop task was closed") - else: - log.exception(__name__, exc) - res_msg = failure(exc) - - if res_msg is not None: - # perform the write outside the big try-except block, so that usb write - # problem bubbles up - await ctx.write(res_msg) - - # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. - return msg.type in AVOID_RESTARTING_FOR - - async def handle_session(iface: WireInterface) -> None: - ctx = context.Context(iface, WIRE_BUFFER) - next_msg: codec_v1.Message | None = None + ctx = CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can # roll back and un-import any others. @@ -203,7 +66,7 @@ async def handle_session(iface: WireInterface) -> None: # wait for a new one coming from the wire. try: msg = await ctx.read_from_wire() - except codec_v1.CodecError as exc: + except protocol_common.WireError as exc: if __debug__: log.exception(__name__, exc) await ctx.write(failure(exc)) @@ -216,8 +79,8 @@ async def handle_session(iface: WireInterface) -> None: do_not_restart = False try: - do_not_restart = await _handle_single_message(ctx, msg) - except context.UnexpectedMessage as unexpected: + do_not_restart = await message_handler.handle_single_message(ctx, msg) + except UnexpectedMessageException as unexpected: # The workflow was interrupted by an unexpected message. We need to # process it as if it was a new message... next_msg = unexpected.msg @@ -230,7 +93,7 @@ async def handle_session(iface: WireInterface) -> None: if __debug__: log.exception(__name__, exc) finally: - # Unload modules imported by the workflow. Should not raise. + # Unload modules imported by the workflow. Should not raise. utils.unimport_end(modules) if not do_not_restart: @@ -243,81 +106,3 @@ async def handle_session(iface: WireInterface) -> None: # loop.clear() above. if __debug__: log.exception(__name__, exc) - - -def find_handler(iface: WireInterface, msg_type: int) -> Handler: - import usb - - from apps import workflow_handlers - - handler = workflow_handlers.find_registered_handler(iface, msg_type) - if handler is None: - raise UnexpectedMessage("Unexpected message") - - if __debug__ and iface is usb.iface_debug: - # no filtering allowed for debuglink - return handler - - for filter in filters: - handler = filter(msg_type, handler) - - return handler - - -filters: list[Filter] = [] -"""Filters for the wire handler. - -Filters are applied in order. Each filter gets a message id and a preceding handler. It -must either return a handler (the same one or a modified one), or raise an exception -that gets sent to wire directly. - -Filters are not applied to debug sessions. - -The filters are designed for: - * rejecting messages -- while in Recovery mode, most messages are not allowed - * adding additional behavior -- while device is soft-locked, a PIN screen will be shown - before allowing a message to trigger its original behavior. - -For this, the filters are effectively deny-first. If an earlier filter rejects the -message, the later filters are not called. But if a filter adds behavior, the latest -filter "wins" and the latest behavior triggers first. -Please note that this behavior is really unsuited to anything other than what we are -using it for now. It might be necessary to modify the semantics if we need more complex -usecases. - -NB: `filters` is currently public so callers can have control over where they insert -new filters, but removal should be done using `remove_filter`! -We should, however, change it such that filters must be added using an `add_filter` -and `filters` becomes private! -""" - - -def remove_filter(filter: Filter) -> None: - try: - filters.remove(filter) - except ValueError: - pass - - -AVOID_RESTARTING_FOR: Container[int] = () - - -def failure(exc: BaseException) -> Failure: - if isinstance(exc, Error): - return Failure(code=exc.code, message=exc.message) - elif isinstance(exc, loop.TaskClosed): - return Failure(code=FailureType.ActionCancelled, message="Cancelled") - elif isinstance(exc, InvalidSessionError): - return Failure(code=FailureType.InvalidSession, message="Invalid session") - else: - # NOTE: when receiving generic `FirmwareError` on non-debug build, - # change the `if __debug__` to `if True` to get the full error message. - if __debug__: - message = str(exc) - else: - message = "Firmware error" - return Failure(code=FailureType.FirmwareError, message=message) - - -def unexpected_message() -> Failure: - return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") diff --git a/core/src/trezor/wire/codec/__init__.py b/core/src/trezor/wire/codec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/core/src/trezor/wire/codec/codec_context.py b/core/src/trezor/wire/codec/codec_context.py new file mode 100644 index 0000000000..2d5a7b7c9a --- /dev/null +++ b/core/src/trezor/wire/codec/codec_context.py @@ -0,0 +1,107 @@ +from typing import TYPE_CHECKING, Awaitable, Container + +from storage import cache_codec +from storage.cache_common import DataCache, InvalidSessionError +from trezor import log, protobuf +from trezor.wire.codec import codec_v1 +from trezor.wire.context import UnexpectedMessageException +from trezor.wire.protocol_common import Context, Message + +if TYPE_CHECKING: + from typing import TypeVar + + from trezor.wire import WireInterface + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + + +class CodecContext(Context): + """ "Wire context" for `protocol_v1`.""" + + def __init__( + self, + iface: WireInterface, + buffer: bytearray, + ) -> None: + self.buffer = buffer + super().__init__(iface) + + def read_from_wire(self) -> Awaitable[Message]: + """Read a whole message from the wire without parsing it.""" + return codec_v1.read_message(self.iface, self.buffer) + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + if __debug__: + log.debug( + __name__, + "%d: expect: %s", + self.iface.iface_num(), + expected_type.MESSAGE_NAME if expected_type else expected_types, + ) + + # Load the full message into a buffer, parse out type and data payload + msg = await self.read_from_wire() + + # If we got a message with unexpected type, raise the message via + # `UnexpectedMessageError` and let the session handler deal with it. + if msg.type not in expected_types: + raise UnexpectedMessageException(msg) + + if expected_type is None: + expected_type = protobuf.type_for_wire(msg.type) + + if __debug__: + log.debug( + __name__, + "%d: read: %s", + self.iface.iface_num(), + expected_type.MESSAGE_NAME, + ) + + # look up the protobuf class and parse the message + from ..message_handler import wrap_protobuf_load + + return wrap_protobuf_load(msg.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + if __debug__: + log.debug( + __name__, + "%d: write: %s", + self.iface.iface_num(), + msg.MESSAGE_NAME, + ) + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + + if msg_size <= len(self.buffer): + # reuse preallocated + buffer = self.buffer + else: + # message is too big, we need to allocate a new buffer + buffer = bytearray(msg_size) + + msg_size = protobuf.encode(buffer, msg) + await codec_v1.write_message( + self.iface, + msg.MESSAGE_WIRE_TYPE, + memoryview(buffer)[:msg_size], + ) + + def release(self) -> None: + cache_codec.end_current_session() + + # ACCESS TO CACHE + @property + def cache(self) -> DataCache: + c = cache_codec.get_active_session() + if c is None: + raise InvalidSessionError() + return c diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py similarity index 94% rename from core/src/trezor/wire/codec_v1.py rename to core/src/trezor/wire/codec/codec_v1.py index d4c8aacf84..02ff37f0ea 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from trezor import io, loop, utils +from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface @@ -16,16 +17,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report -class CodecError(Exception): +class CodecError(WireError): pass -class Message: - def __init__(self, mtype: int, mdata: bytes) -> None: - self.type = mtype - self.data = mdata - - async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 10248c871a..56df34fbc5 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -15,22 +15,16 @@ from typing import TYPE_CHECKING -from trezor import log, loop, protobuf +from storage import cache +from storage.cache_common import SESSIONLESS_FLAG +from trezor import loop, protobuf -from . import codec_v1 +from .protocol_common import Context, Message if TYPE_CHECKING: - from trezorio import WireInterface - from typing import ( - Any, - Awaitable, - Callable, - Container, - Coroutine, - Generator, - TypeVar, - overload, - ) + from typing import Any, Callable, Coroutine, Generator, TypeVar, overload + + from storage.cache_common import DataCache Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] @@ -41,130 +35,18 @@ T = TypeVar("T") -class UnexpectedMessage(Exception): +class UnexpectedMessageException(Exception): """A message was received that is not part of the current workflow. Utility exception to inform the session handler that the current workflow should be aborted and a new one started as if `msg` was the first message. """ - def __init__(self, msg: codec_v1.Message) -> None: + def __init__(self, msg: Message) -> None: super().__init__() self.msg = msg -class Context: - """Wire context. - - Represents USB communication inside a particular session on a particular interface - (i.e., wire, debug, single BT connection, etc.) - """ - - def __init__(self, iface: WireInterface, buffer: bytearray) -> None: - self.iface = iface - self.buffer = buffer - - def read_from_wire(self) -> Awaitable[codec_v1.Message]: - """Read a whole message from the wire without parsing it.""" - return codec_v1.read_message(self.iface, self.buffer) - - if TYPE_CHECKING: - - @overload - async def read( - self, expected_types: Container[int] - ) -> protobuf.MessageType: ... - - @overload - async def read( - self, expected_types: Container[int], expected_type: type[LoadedMessageType] - ) -> LoadedMessageType: ... - - async def read( - self, - expected_types: Container[int], - expected_type: type[protobuf.MessageType] | None = None, - ) -> protobuf.MessageType: - """Read a message from the wire. - - The read message must be of one of the types specified in `expected_types`. - If only a single type is expected, it can be passed as `expected_type`, - to save on having to decode the type code into a protobuf class. - """ - if __debug__: - log.debug( - __name__, - "%d expect: %s", - self.iface.iface_num(), - expected_type.MESSAGE_NAME if expected_type else expected_types, - ) - - # Load the full message into a buffer, parse out type and data payload - msg = await self.read_from_wire() - - # If we got a message with unexpected type, raise the message via - # `UnexpectedMessageError` and let the session handler deal with it. - if msg.type not in expected_types: - raise UnexpectedMessage(msg) - - if expected_type is None: - expected_type = protobuf.type_for_wire(msg.type) - - if __debug__: - log.debug( - __name__, - "%d read: %s", - self.iface.iface_num(), - expected_type.MESSAGE_NAME, - ) - - # look up the protobuf class and parse the message - from . import wrap_protobuf_load - - return wrap_protobuf_load(msg.data, expected_type) - - async def write(self, msg: protobuf.MessageType) -> None: - """Write a message to the wire.""" - if __debug__: - log.debug( - __name__, - "%d write: %s", - self.iface.iface_num(), - msg.MESSAGE_NAME, - ) - - # cannot write message without wire type - assert msg.MESSAGE_WIRE_TYPE is not None - - msg_size = protobuf.encoded_length(msg) - - if msg_size <= len(self.buffer): - # reuse preallocated - buffer = self.buffer - else: - # message is too big, we need to allocate a new buffer - buffer = bytearray(msg_size) - - msg_size = protobuf.encode(buffer, msg) - - await codec_v1.write_message( - self.iface, - msg.MESSAGE_WIRE_TYPE, - memoryview(buffer)[:msg_size], - ) - - async def call( - self, - msg: protobuf.MessageType, - expected_type: type[LoadedMessageType], - ) -> LoadedMessageType: - assert expected_type.MESSAGE_WIRE_TYPE is not None - - await self.write(msg) - del msg - return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) - - CURRENT_CONTEXT: Context | None = None @@ -254,3 +136,69 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = e else: send_exc = None + + +# ACCESS TO CACHE + +if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get(key, default) + + +def cache_get_bool(key: int) -> bool: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_bool(key) + + +def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 + cache = _get_cache_for_key(key) + return cache.get_int(key, default) + + +def cache_get_int_all_sessions(key: int) -> set[int]: + return cache.get_int_all_sessions(key) + + +def cache_is_set(key: int) -> bool: + cache = _get_cache_for_key(key) + return cache.is_set(key) + + +def cache_set(key: int, value: bytes) -> None: + cache = _get_cache_for_key(key) + cache.set(key, value) + + +def cache_set_bool(key: int, value: bool) -> None: + cache = _get_cache_for_key(key) + cache.set_bool(key, value) + + +def cache_set_int(key: int, value: int) -> None: + cache = _get_cache_for_key(key) + cache.set_int(key, value) + + +def cache_delete(key: int) -> None: + cache = _get_cache_for_key(key) + cache.delete(key) + + +def _get_cache_for_key(key: int) -> DataCache: + if key & SESSIONLESS_FLAG: + return cache.get_sessionless_cache() + if CURRENT_CONTEXT: + return CURRENT_CONTEXT.cache + raise Exception("No wire context") diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py new file mode 100644 index 0000000000..67c1da16c2 --- /dev/null +++ b/core/src/trezor/wire/message_handler.py @@ -0,0 +1,234 @@ +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import InvalidSessionError +from trezor import log, loop, protobuf, utils, workflow +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire.context import UnexpectedMessageException, with_context +from trezor.wire.errors import ActionCancelled, DataError, Error, UnexpectedMessage +from trezor.wire.protocol_common import Context, Message + +if TYPE_CHECKING: + from typing import Any, Callable, Container + + from trezor.wire import Handler, LoadedMessageType, WireInterface + + HandlerFinder = Callable[[Any, Any], Handler | None] + Filter = Callable[[int, Handler], Handler] + +# If set to False protobuf messages marked with "experimental_message" option are rejected. +EXPERIMENTAL_ENABLED = False + + +def wrap_protobuf_load( + buffer: bytes, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: + try: + if __debug__ and utils.EMULATOR and utils.USE_THP: + log.debug( + __name__, + "Buffer to be parsed to a LoadedMessage: %s", + utils.get_bytes_as_str(buffer), + ) + msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) + if __debug__ and utils.EMULATOR: + log.debug( + __name__, "received message contents:\n%s", utils.dump_protobuf(msg) + ) + return msg + except Exception as e: + if __debug__: + log.exception(__name__, e) + if e.args: + raise DataError("Failed to decode message: " + " ".join(e.args)) + else: + raise DataError("Failed to decode message") + + +_PROTOBUF_BUFFER_SIZE = const(8192) + +WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + + +async def handle_single_message(ctx: Context, msg: Message) -> bool: + """Handle a message that was loaded from a WireInterface by the caller. + + Find the appropriate handler, run it and write its result on the wire. In case + a problem is encountered at any point, write the appropriate error on the wire. + + The return value indicates whether to override the default restarting behavior. If + `False` is returned, the caller is allowed to clear the loop and restart the + MicroPython machine (see `session.py`). This would lose all state and incurs a cost + in terms of repeated startup time. When handling the message didn't cause any + significant fragmentation (e.g., if decoding the message was skipped), or if + the type of message is supposed to be optimized and not disrupt the running state, + this function will return `True`. + """ + if __debug__: + try: + msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + except Exception: + msg_type = f"{msg.type} - unknown message type" + log.debug( + __name__, + "%d receive: <%s>", + ctx.iface.iface_num(), + msg_type, + ) + + res_msg: protobuf.MessageType | None = None + + # We need to find a handler for this message type. + try: + handler: Handler = find_handler(ctx.iface, msg.type) + except Error as exc: + # Handlers are allowed to exception out. In that case, we can skip decoding + # and return the error. + await ctx.write(failure(exc)) + return True + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = protobuf.type_for_wire(msg.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handler(req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + + # Spawn a workflow around the task. This ensures that concurrent + # workflows are shut down. + res_msg = await workflow.spawn(with_context(ctx, task)) + + except UnexpectedMessageException: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + + # We process the unexpected message by aborting the current workflow and + # possibly starting a new one, initiated by that message. (The main usecase + # being, the host does not finish the workflow, we want other callers to + # be able to do their own thing.) + # + # The message is stored in the exception, which we re-raise for the caller + # to process. It is not a standard exception that should be logged and a result + # sent to the wire. + raise + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await ctx.write(res_msg) + + # Look into `AVOID_RESTARTING_FOR` to see if this message should avoid restarting. + return msg.type in AVOID_RESTARTING_FOR + + +AVOID_RESTARTING_FOR: Container[int] = () + + +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + elif isinstance(exc, loop.TaskClosed): + return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, InvalidSessionError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") + else: + # NOTE: when receiving generic `FirmwareError` on non-debug build, + # change the `if __debug__` to `if True` to get the full error message. + if __debug__: + message = str(exc) + else: + message = "Firmware error" + return Failure(code=FailureType.FirmwareError, message=message) + + +def unexpected_message() -> Failure: + return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") + + +def find_handler(iface: WireInterface, msg_type: int) -> Handler: + import usb + + from apps import workflow_handlers + + handler = workflow_handlers.find_registered_handler(msg_type) + if handler is None: + raise UnexpectedMessage("Unexpected message") + + if __debug__ and iface is usb.iface_debug: + # no filtering allowed for debuglink + return handler + + for filter in filters: + handler = filter(msg_type, handler) + + return handler + + +filters: list[Filter] = [] +"""Filters for the wire handler. + +Filters are applied in order. Each filter gets a message id and a preceding handler. It +must either return a handler (the same one or a modified one), or raise an exception +that gets sent to wire directly. + +Filters are not applied to debug sessions. + +The filters are designed for: + * rejecting messages -- while in Recovery mode, most messages are not allowed + * adding additional behavior -- while device is soft-locked, a PIN screen will be shown + before allowing a message to trigger its original behavior. + +For this, the filters are effectively deny-first. If an earlier filter rejects the +message, the later filters are not called. But if a filter adds behavior, the latest +filter "wins" and the latest behavior triggers first. +Please note that this behavior is really unsuited to anything other than what we are +using it for now. It might be necessary to modify the semantics if we need more complex +usecases. + +NB: `filters` is currently public so callers can have control over where they insert +new filters, but removal should be done using `remove_filter`! +We should, however, change it such that filters must be added using an `add_filter` +and `filters` becomes private! +""" + + +def remove_filter(filter: Filter) -> None: + try: + filters.remove(filter) + except ValueError: + pass diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py new file mode 100644 index 0000000000..ed4105517b --- /dev/null +++ b/core/src/trezor/wire/protocol_common.py @@ -0,0 +1,98 @@ +from typing import TYPE_CHECKING + +from trezor import protobuf + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Container, TypeVar, overload + + from storage.cache_common import DataCache + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + T = TypeVar("T") + + +class Message: + """ + Encapsulates protobuf encoded message, where + - `type` is the `WIRE_TYPE` of the message + - `data` is the protobuf encoded message + """ + + def __init__( + self, + message_type: int, + message_data: bytes, + ) -> None: + self.data = message_data + self.type = message_type + + +class Context: + """Wire context. + + Represents communication between the Trezor device and a host within + a specific session over a particular interface (i.e., wire, debug, + single Bluetooth connection, etc.). + """ + + channel_id: bytes + + def __init__(self, iface: WireInterface, channel_id: bytes | None = None) -> None: + self.iface: WireInterface = iface + if channel_id is not None: + self.channel_id = channel_id + + if TYPE_CHECKING: + + @overload + async def read( + self, expected_types: Container[int] + ) -> protobuf.MessageType: ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: ... + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + """Read a message from the wire. + + The read message must be of one of the types specified in `expected_types`. + If only a single type is expected, it can be passed as `expected_type`, + to save on having to decode the type code into a protobuf class. + """ + ... + + async def write(self, msg: protobuf.MessageType) -> None: + """Write a message to the wire.""" + ... + + async def call( + self, + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + """Write a message to the wire, then await and return the response message.""" + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + def release(self) -> None: + """Release resources used by the context, eg. clear context cache.""" + pass + + @property + def cache(self) -> DataCache: + """Access to the backing cache of the context, if the context has any.""" + ... + + +class WireError(Exception): + pass diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 1252a1bf5f..67b88f8e68 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,7 +1,7 @@ import utime from typing import TYPE_CHECKING -import storage.cache +import storage.cache as storage_cache from trezor import log, loop from trezor.enums import MessageType @@ -153,7 +153,7 @@ def close_others() -> None: if not task.is_running(): task.close() - storage.cache.homescreen_shown = None + storage_cache.homescreen_shown = None # if tasks were running, closing the last of them will run start_default @@ -211,11 +211,11 @@ def touch(self, _restore_from_cache: bool = False) -> None: time and saves it to storage.cache. This is done to avoid losing an active timer when workflow restart happens and tasks are lost. """ - if _restore_from_cache and storage.cache.autolock_last_touch is not None: - now = storage.cache.autolock_last_touch + if _restore_from_cache and storage_cache.autolock_last_touch is not None: + now = storage_cache.autolock_last_touch else: now = utime.ticks_ms() - storage.cache.autolock_last_touch = now + storage_cache.autolock_last_touch = now for callback, task in self.tasks.items(): timeout_us = self.timeouts[callback] diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py index 7354a846b1..22888546f0 100644 --- a/core/tests/test_apps.bitcoin.approver.py +++ b/core/tests/test_apps.bitcoin.approver.py @@ -1,6 +1,6 @@ from common import H_, await_result, unittest # isort:skip -import storage.cache +import storage.cache_codec from trezor import wire from trezor.crypto import bip32 from trezor.enums import InputScriptType, OutputScriptType @@ -11,6 +11,8 @@ TxInput, TxOutput, ) +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.sign_tx.approvers import CoinJoinApprover @@ -20,6 +22,13 @@ class TestApprover(unittest.TestCase): + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): self.coin = coins.by_name("Bitcoin") self.fee_rate_percent = 0.3 @@ -47,7 +56,7 @@ def setUp(self): coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) - storage.cache.start_session() + storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py index 503c181569..03d32651c7 100644 --- a/core/tests/test_apps.bitcoin.authorization.py +++ b/core/tests/test_apps.bitcoin.authorization.py @@ -1,8 +1,10 @@ from common import H_, unittest # isort:skip -import storage.cache +import storage.cache_codec from trezor.enums import InputScriptType from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import CoinJoinAuthorization from apps.common import coins @@ -14,6 +16,12 @@ class TestAuthorization(unittest.TestCase): coin = coins.by_name("Bitcoin") + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): self.msg_auth = AuthorizeCoinJoin( coordinator="www.example.com", @@ -26,7 +34,7 @@ def setUp(self): ) self.authorization = CoinJoinAuthorization(self.msg_auth) - storage.cache.start_session() + storage.cache_codec.start_session() def test_ownership_proof_account_depth_mismatch(self): # Account depth mismatch. diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index 3828a3ebbc..a232a000ae 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -1,17 +1,27 @@ from common import * # isort:skip -from storage import cache +from storage import cache_common from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec class TestBitcoinKeychain(unittest.TestCase): + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") @@ -88,10 +98,17 @@ def test_unknown(self): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bcash(self): coin = _get_coin_by_name("Bcash") diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index 84681a0b01..fa2e3ff041 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -1,19 +1,29 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache +from storage import cache, cache_common from trezor import wire from trezor.crypto import bip39 from trezor.enums import SafetyCheckLevel +from trezor.wire import context from apps.common import safety_checks from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec class TestKeychain(unittest.TestCase): + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): - cache.start_session() + cache_codec.start_session() def tearDown(self): cache.clear_all() @@ -71,7 +81,7 @@ def test_no_schemas(self): def test_get_keychain(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + context.cache_set(cache_common.APP_COMMON_SEED, seed) schema = PathSchema.parse("m/44'/1'", 0) keychain = await_result(get_keychain("secp256k1", [schema])) @@ -85,7 +95,7 @@ def test_get_keychain(self): def test_with_slip44(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + context.cache_set(cache_common.APP_COMMON_SEED, seed) slip44_id = 42 valid_path = [H_(44), H_(slip44_id), H_(0)] diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 53affef1b7..a00b412a5f 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -2,12 +2,15 @@ import unittest -from storage import cache -from trezor import utils, wire +from storage import cache_common +from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.common.keychain import get_keychain from apps.common.paths import HARDENED +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network @@ -71,10 +74,16 @@ def _check_keychain(self, keychain, slip44_id): addr, ) + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) diff --git a/core/tests/test_apps.monero.serializer.py b/core/tests/test_apps.monero.serializer.py index 7f11afbd39..2194a07f98 100644 --- a/core/tests/test_apps.monero.serializer.py +++ b/core/tests/test_apps.monero.serializer.py @@ -11,9 +11,6 @@ @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestMoneroSerializer(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestMoneroSerializer, self).__init__(*args, **kwargs) - def test_varint(self): """ Var int diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 76fe29655b..25eb119bd3 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,150 +1,160 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache +from storage import cache, cache_codec, cache_common from trezor.messages import EndSession, Initialize +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.base import handle_EndSession, handle_Initialize +from apps.common.cache import stored, stored_async KEY = 0 # Function moved from cache.py, as it was not used there def is_session_started() -> bool: - return cache._active_session_idx is not None + return cache_codec._active_session_idx is not None class TestStorageCache(unittest.TestCase): + + def setUpClass(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + + def tearDownClass(self): + context.CURRENT_CONTEXT = None + def setUp(self): cache.clear_all() def test_start_session(self): - session_id_a = cache.start_session() + session_id_a = cache_codec.start_session() self.assertIsNotNone(session_id_a) - session_id_b = cache.start_session() + session_id_b = cache_codec.start_session() self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.set(KEY, "something") - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_set(KEY, "something") + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get(KEY) def test_end_session(self): - session_id = cache.start_session() + session_id = cache_codec.start_session() self.assertTrue(is_session_started()) - cache.set(KEY, b"A") - cache.end_current_session() + context.cache_set(KEY, b"A") + cache_codec.end_current_session() self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) # ending an ended session should be a no-op - cache.end_current_session() + cache_codec.end_current_session() self.assertFalse(is_session_started()) - session_id_a = cache.start_session(session_id) + session_id_a = cache_codec.start_session(session_id) # original session no longer exists self.assertNotEqual(session_id_a, session_id) # original session data no longer exists - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) # create a new session - session_id_b = cache.start_session() + session_id_b = cache_codec.start_session() # switch back to original session - session_id = cache.start_session(session_id_a) + session_id = cache_codec.start_session(session_id_a) self.assertEqual(session_id, session_id_a) # end original session - cache.end_current_session() + cache_codec.end_current_session() # switch back to B - session_id = cache.start_session(session_id_b) + session_id = cache_codec.start_session(session_id_b) self.assertEqual(session_id, session_id_b) def test_session_queue(self): - session_id = cache.start_session() - self.assertEqual(cache.start_session(session_id), session_id) - cache.set(KEY, b"A") - for i in range(cache._MAX_SESSIONS_COUNT): - cache.start_session() - self.assertNotEqual(cache.start_session(session_id), session_id) - self.assertIsNone(cache.get(KEY)) + session_id = cache_codec.start_session() + self.assertEqual(cache_codec.start_session(session_id), session_id) + context.cache_set(KEY, b"A") + for i in range(cache_codec._MAX_SESSIONS_COUNT): + cache_codec.start_session() + self.assertNotEqual(cache_codec.start_session(session_id), session_id) + self.assertIsNone(context.cache_get(KEY)) def test_get_set(self): - session_id1 = cache.start_session() - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") + session_id1 = cache_codec.start_session() + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") - session_id2 = cache.start_session() - cache.set(KEY, b"world") - self.assertEqual(cache.get(KEY), b"world") + session_id2 = cache_codec.start_session() + context.cache_set(KEY, b"world") + self.assertEqual(context.cache_get(KEY), b"world") - cache.start_session(session_id2) - self.assertEqual(cache.get(KEY), b"world") - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") + cache_codec.start_session(session_id2) + self.assertEqual(context.cache_get(KEY), b"world") + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get(KEY), b"hello") cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get(KEY) def test_get_set_int(self): - session_id1 = cache.start_session() - cache.set_int(KEY, 1234) - self.assertEqual(cache.get_int(KEY), 1234) + session_id1 = cache_codec.start_session() + context.cache_set_int(KEY, 1234) + self.assertEqual(context.cache_get_int(KEY), 1234) - session_id2 = cache.start_session() - cache.set_int(KEY, 5678) - self.assertEqual(cache.get_int(KEY), 5678) + session_id2 = cache_codec.start_session() + context.cache_set_int(KEY, 5678) + self.assertEqual(context.cache_get_int(KEY), 5678) - cache.start_session(session_id2) - self.assertEqual(cache.get_int(KEY), 5678) - cache.start_session(session_id1) - self.assertEqual(cache.get_int(KEY), 1234) + cache_codec.start_session(session_id2) + self.assertEqual(context.cache_get_int(KEY), 5678) + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get_int(KEY), 1234) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get_int(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get_int(KEY) def test_delete(self): - session_id1 = cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) - - cache.set(KEY, b"hello") - cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) - - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") + session_id1 = cache_codec.start_session() + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") + context.cache_delete(KEY) + self.assertIsNone(context.cache_get(KEY)) + + context.cache_set(KEY, b"hello") + cache_codec.start_session() + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") + context.cache_delete(KEY) + self.assertIsNone(context.cache_get(KEY)) + + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get(KEY), b"hello") def test_decorators(self): run_count = 0 - cache.start_session() + cache_codec.start_session() - @cache.stored(KEY) + @stored(KEY) def func(): nonlocal run_count run_count += 1 return b"foo" # cache is empty - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) self.assertEqual(run_count, 0) self.assertEqual(func(), b"foo") # function was run self.assertEqual(run_count, 1) - self.assertEqual(cache.get(KEY), b"foo") + self.assertEqual(context.cache_get(KEY), b"foo") # function does not run again but returns cached value self.assertEqual(func(), b"foo") self.assertEqual(run_count, 1) - @cache.stored_async(KEY) + @stored_async(KEY) async def async_func(): nonlocal run_count run_count += 1 @@ -154,7 +164,7 @@ async def async_func(): self.assertEqual(await_result(async_func()), b"foo") self.assertEqual(run_count, 1) - cache.start_session() + cache_codec.start_session() self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(run_count, 2) # awaitable is also run only once @@ -162,16 +172,16 @@ async def async_func(): self.assertEqual(run_count, 2) def test_empty_value(self): - cache.start_session() + cache_codec.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"") - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"") + self.assertEqual(context.cache_get(KEY), b"") - cache.delete(KEY) + context.cache_delete(KEY) run_count = 0 - @cache.stored(KEY) + @stored(KEY) def func(): nonlocal run_count run_count += 1 @@ -191,7 +201,7 @@ def call_Initialize(**kwargs): return await_result(handle_Initialize(msg)) # calling Initialize without an ID allocates a new one - session_id = cache.start_session() + session_id = cache_codec.start_session() features = call_Initialize() self.assertNotEqual(session_id, features.session_id) @@ -200,31 +210,31 @@ def call_Initialize(**kwargs): self.assertEqual(session_id, features.session_id) # store "hello" - cache.set(KEY, b"hello") + context.cache_set(KEY, b"hello") # check that it is cleared features = call_Initialize() session_id = features.session_id - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) # store "hello" again - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) - self.assertIsNone(cache.get(KEY)) + call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH) + self.assertIsNone(context.cache_get(KEY)) # but resuming a session loads the previous one call_Initialize(session_id=session_id) - self.assertEqual(cache.get(KEY), b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") def test_EndSession(self): - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) - cache.start_session() + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) + cache_codec.start_session() self.assertTrue(is_session_started()) - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) await_result(handle_EndSession(EndSession())) self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py similarity index 99% rename from core/tests/test_trezor.wire.codec_v1.py rename to core/tests/test_trezor.wire.codec.codec_v1.py index 1da0ea896b..78675859e2 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -5,7 +5,7 @@ from trezor import io from trezor.loop import wait from trezor.utils import chunks -from trezor.wire import codec_v1 +from trezor.wire.codec import codec_v1 class MockHID: diff --git a/core/tests/unittest.py b/core/tests/unittest.py index 00e398cc27..c9ecab089f 100644 --- a/core/tests/unittest.py +++ b/core/tests/unittest.py @@ -236,38 +236,48 @@ def wasSuccessful(self): def run_class(c, test_result): o = c() + set_up_class = getattr(o, "setUpClass", lambda: None) + tear_down_class = getattr(o, "tearDownClass", lambda: None) set_up = getattr(o, "setUp", lambda: None) tear_down = getattr(o, "tearDown", lambda: None) print("class", c.__qualname__) - for name in dir(o): - if name.startswith("test"): - print(" ", name, end=" ...") - m = getattr(o, name) - try: - try: - set_up() - test_result.testsRun += 1 - retval = m() - if isinstance(retval, generator_type): - raise RuntimeError( - f"{name} must not be a generator (it is async, uses yield or await)." - ) - elif retval is not None: - raise RuntimeError(f"{name} should not return a result.") - finally: - tear_down() - print(f"{OK_COLOR} ok{DEFAULT_COLOR}") - except SkipTest as e: - print(" skipped:", e.args[0]) - test_result.skippedNum += 1 - except AssertionError as e: - print(f"{ERROR_COLOR} failed{DEFAULT_COLOR}") - sys.print_exception(e) - test_result.failuresNum += 1 - except BaseException as e: - print(f"{ERROR_COLOR} errored:{DEFAULT_COLOR}", e) - sys.print_exception(e) - test_result.errorsNum += 1 + try: + set_up_class() + for name in dir(o): + if name.startswith("test"): + run_test_method(o, name, set_up, tear_down, test_result) + finally: + tear_down_class() + + +def run_test_method(o, name, set_up, tear_down, test_result): + print(" ", name, end=" ...") + m = getattr(o, name) + try: + try: + set_up() + test_result.testsRun += 1 + retval = m() + if isinstance(retval, generator_type): + raise RuntimeError( + f"{name} must not be a generator (it is async, uses yield or await)." + ) + elif retval is not None: + raise RuntimeError(f"{name} should not return a result.") + finally: + tear_down() + print(f"{OK_COLOR} ok{DEFAULT_COLOR}") + except SkipTest as e: + print(" skipped:", e.args[0]) + test_result.skippedNum += 1 + except AssertionError as e: + print(f"{ERROR_COLOR} failed{DEFAULT_COLOR}") + sys.print_exception(e) + test_result.failuresNum += 1 + except BaseException as e: + print(f"{ERROR_COLOR} errored:{DEFAULT_COLOR}", e) + sys.print_exception(e) + test_result.errorsNum += 1 def main(module="__main__"): diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 7fd6dae32b..2ec853dfd3 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -100,7 +100,7 @@ def __init__( You have to provide a `transport`, i.e., a raw connection to the device. You can use `trezorlib.transport.get_transport` to find one. - You have to provide an UI implementation for the three kinds of interaction: + You have to provide a UI implementation for the three kinds of interaction: - button request (notify the user that their interaction is needed) - PIN request (on T1, ask the user to input numbers for a PIN matrix) - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for