diff --git a/chia/_tests/blockchain/blockchain_test_utils.py b/chia/_tests/blockchain/blockchain_test_utils.py
index ea39623dc86c..2e2b72decb3a 100644
--- a/chia/_tests/blockchain/blockchain_test_utils.py
+++ b/chia/_tests/blockchain/blockchain_test_utils.py
@@ -6,6 +6,7 @@
 from chia.consensus.blockchain import AddBlockResult, Blockchain
 from chia.consensus.multiprocess_validation import PreValidationResult
 from chia.types.full_block import FullBlock
+from chia.util.cached_bls import BLSCache
 from chia.util.errors import Err
 from chia.util.ints import uint32, uint64
 
@@ -42,10 +43,12 @@ async def check_block_store_invariant(bc: Blockchain):
 async def _validate_and_add_block(
     blockchain: Blockchain,
     block: FullBlock,
+    *,
     expected_result: Optional[AddBlockResult] = None,
     expected_error: Optional[Err] = None,
     skip_prevalidation: bool = False,
     fork_info: Optional[ForkInfo] = None,
+    use_bls_cache: bool = False,
 ) -> None:
     # Tries to validate and add the block, and checks that there are no errors in the process and that the
     # block is added to the peak.
@@ -58,7 +61,8 @@ async def _validate_and_add_block(
     if skip_prevalidation:
         results = PreValidationResult(None, uint64(1), None, False, uint32(0))
     else:
-        # Do not change this, validate_signatures must be False
+        # validate_signatures must be False in order to trigger add_block() to
+        # validate the signature.
         pre_validation_results: List[PreValidationResult] = await blockchain.pre_validate_blocks_multiprocessing(
             [block], {}, validate_signatures=False
         )
@@ -78,11 +82,16 @@ async def _validate_and_add_block(
             await check_block_store_invariant(blockchain)
             return None
 
+    if use_bls_cache:
+        bls_cache = BLSCache(100)
+    else:
+        bls_cache = None
+
     (
         result,
         err,
         _,
-    ) = await blockchain.add_block(block, results, fork_info=fork_info)
+    ) = await blockchain.add_block(block, results, bls_cache, fork_info=fork_info)
     await check_block_store_invariant(blockchain)
 
     if expected_error is None and expected_result != AddBlockResult.INVALID_BLOCK:
diff --git a/chia/_tests/blockchain/test_blockchain.py b/chia/_tests/blockchain/test_blockchain.py
index 43df80e7bdd7..d51e9eed33aa 100644
--- a/chia/_tests/blockchain/test_blockchain.py
+++ b/chia/_tests/blockchain/test_blockchain.py
@@ -1839,7 +1839,7 @@ async def test_pre_validation(
             assert res[n].error is None
             block = blocks_to_validate[n]
             start_rb = time.time()
-            result, err, _ = await empty_blockchain.add_block(block, res[n])
+            result, err, _ = await empty_blockchain.add_block(block, res[n], None)
             end_rb = time.time()
             times_rb.append(end_rb - start_rb)
             assert err is None
@@ -1934,7 +1934,7 @@ async def test_conditions(
         )
         # Ignore errors from pre-validation, we are testing block_body_validation
         repl_preval_results = replace(pre_validation_results[0], error=None, required_iters=uint64(1))
-        code, err, state_change = await b.add_block(blocks[-1], repl_preval_results)
+        code, err, state_change = await b.add_block(blocks[-1], repl_preval_results, None)
         assert code == AddBlockResult.NEW_PEAK
         assert err is None
         assert state_change is not None
@@ -2050,7 +2050,7 @@ async def test_timelock_conditions(
             [blocks[-1]], {}, validate_signatures=True
         )
         assert pre_validation_results is not None
-        assert (await b.add_block(blocks[-1], pre_validation_results[0]))[0] == expected
+        assert (await b.add_block(blocks[-1], pre_validation_results[0], None))[0] == expected
 
         if expected == AddBlockResult.NEW_PEAK:
             # ensure coin was in fact spent
@@ -2152,7 +2152,7 @@ async def test_aggsig_garbage(
         )
         # Ignore errors from pre-validation, we are testing block_body_validation
         repl_preval_results = replace(pre_validation_results[0], error=None, required_iters=uint64(1))
-        res, error, state_change = await b.add_block(blocks[-1], repl_preval_results)
+        res, error, state_change = await b.add_block(blocks[-1], repl_preval_results, None)
         assert (res, error, state_change.fork_height if state_change else None) == expected
 
     @pytest.mark.anyio
@@ -2268,7 +2268,7 @@ async def test_ephemeral_timelock(
             [blocks[-1]], {}, validate_signatures=True
         )
         assert pre_validation_results is not None
-        assert (await b.add_block(blocks[-1], pre_validation_results[0]))[0] == expected
+        assert (await b.add_block(blocks[-1], pre_validation_results[0], None))[0] == expected
 
         if expected == AddBlockResult.NEW_PEAK:
             # ensure coin1 was in fact spent
@@ -2657,7 +2657,9 @@ async def test_cost_exceeds_max(
             height=softfork_height,
             constants=bt.constants,
         )
-        err = (await b.add_block(blocks[-1], PreValidationResult(None, uint64(1), npc_result, True, uint32(0))))[1]
+        err = (await b.add_block(blocks[-1], PreValidationResult(None, uint64(1), npc_result, True, uint32(0)), None))[
+            1
+        ]
         assert err in [Err.BLOCK_COST_EXCEEDS_MAX]
 
         results: List[PreValidationResult] = await b.pre_validate_blocks_multiprocessing(
@@ -2722,7 +2724,7 @@ async def test_invalid_cost_in_block(
             height=softfork_height,
             constants=bt.constants,
         )
-        _, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
+        _, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None)
         assert err == Err.INVALID_BLOCK_COST
 
         # too low
@@ -2749,7 +2751,7 @@ async def test_invalid_cost_in_block(
             height=softfork_height,
             constants=bt.constants,
         )
-        _, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
+        _, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None)
         assert err == Err.INVALID_BLOCK_COST
 
         # too high
@@ -2778,7 +2780,9 @@ async def test_invalid_cost_in_block(
             block_generator, max_cost, mempool_mode=False, height=softfork_height, constants=bt.constants
         )
-        result, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
+        result, err, _ = await b.add_block(
+            block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None
+        )
         assert err == Err.INVALID_BLOCK_COST
 
         # when the CLVM program exceeds cost during execution, it will fail with
@@ -3220,6 +3224,8 @@ async def test_invalid_agg_sig(self, empty_blockchain: Blockchain, bt: BlockTool
 
         # Bad signature fails during add_block
         await _validate_and_add_block(b, last_block, expected_error=Err.BAD_AGGREGATE_SIGNATURE)
+        # Also test the same case but when using BLSCache
+        await _validate_and_add_block(b, last_block, expected_error=Err.BAD_AGGREGATE_SIGNATURE, use_bls_cache=True)
 
         # Bad signature also fails in prevalidation
         preval_results = await b.pre_validate_blocks_multiprocessing([last_block], {}, validate_signatures=True)
@@ -3336,7 +3342,7 @@ async def test_long_reorg(
         assert pre_validation_results[i].error is None
         if (block.height % 100) == 0:
             print(f"main chain: {block.height:4} weight: {block.weight}")
-        (result, err, _) = await b.add_block(block, pre_validation_results[i])
+        (result, err, _) = await b.add_block(block, pre_validation_results[i], None)
         await check_block_store_invariant(b)
         assert err is None
         assert result == AddBlockResult.NEW_PEAK
@@ -3874,10 +3880,10 @@ async def test_reorg_flip_flop(empty_blockchain: Blockchain, bt: BlockTools) ->
         preval: List[PreValidationResult] = await b.pre_validate_blocks_multiprocessing(
             [block1], {}, validate_signatures=False
         )
-        _, err, _ = await b.add_block(block1, preval[0])
+        _, err, _ = await b.add_block(block1, preval[0], None)
         assert err is None
         preval = await b.pre_validate_blocks_multiprocessing([block2], {}, validate_signatures=False)
-        _, err, _ = await b.add_block(block2, preval[0])
+        _, err, _ = await b.add_block(block2, preval[0], None)
         assert err is None
 
     peak = b.get_peak()
@@ -3905,7 +3911,7 @@ async def test_get_tx_peak(default_400_blocks: List[FullBlock], empty_blockchain
     last_tx_block_record = None
     for b, prevalidation_res in zip(test_blocks, res):
         assert bc.get_tx_peak() == last_tx_block_record
-        _, err, _ = await bc.add_block(b, prevalidation_res)
+        _, err, _ = await bc.add_block(b, prevalidation_res, None)
         assert err is None
 
         if b.is_transaction_block():
diff --git a/chia/_tests/core/test_db_conversion.py b/chia/_tests/core/test_db_conversion.py
index b47ac036e702..8eea4aaa25bf 100644
--- a/chia/_tests/core/test_db_conversion.py
+++ b/chia/_tests/core/test_db_conversion.py
@@ -72,7 +72,7 @@ async def test_blocks(default_1000_blocks, with_hints: bool):
             for block in blocks:
                 # await _validate_and_add_block(bc, block)
                 results = PreValidationResult(None, uint64(1), None, False, uint32(0))
-                result, err, _ = await bc.add_block(block, results)
+                result, err, _ = await bc.add_block(block, results, None)
                 assert err is None
 
     # now, convert v1 in_file to v2 out_file
diff --git a/chia/_tests/core/test_db_validation.py b/chia/_tests/core/test_db_validation.py
index b3960c20504d..a5290666f6d4 100644
--- a/chia/_tests/core/test_db_validation.py
+++ b/chia/_tests/core/test_db_validation.py
@@ -142,7 +142,7 @@ async def make_db(db_file: Path, blocks: List[FullBlock]) -> None:
 
         for block in blocks:
             results = PreValidationResult(None, uint64(1), None, False, uint32(0))
-            result, err, _ = await bc.add_block(block, results)
+            result, err, _ = await bc.add_block(block, results, None)
             assert err is None
diff --git a/chia/_tests/core/util/test_cached_bls.py b/chia/_tests/core/util/test_cached_bls.py
index 94fdab118654..9dfc35f689de 100644
--- a/chia/_tests/core/util/test_cached_bls.py
+++ b/chia/_tests/core/util/test_cached_bls.py
@@ -2,9 +2,10 @@
 
 from chia_rs import AugSchemeMPL
 
-from chia.util import cached_bls
+from chia.util.cached_bls import BLSCache
 from chia.util.hash import std_hash
-from chia.util.lru_cache import LRUCache
+
+LOCAL_CACHE = BLSCache(50000)
 
 
 def test_cached_bls():
@@ -25,21 +26,21 @@ def test_cached_bls():
     assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)
 
     # Verify with empty cache and populate it
-    assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
+    assert LOCAL_CACHE.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
     # Verify with partial cache hit
-    assert cached_bls.aggregate_verify(pks, msgs, agg_sig, True)
+    assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig, True)
     # Verify with full cache hit
-    assert cached_bls.aggregate_verify(pks, msgs, agg_sig)
+    assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig)
 
     # Use a small cache which can not accommodate all pairings
-    local_cache = LRUCache(n_keys // 2)
+    local_cache = BLSCache(n_keys // 2)
     # Verify signatures and cache pairings one at a time
     for pk, msg, sig in zip(pks_half, msgs_half, sigs_half):
-        assert cached_bls.aggregate_verify([pk], [msg], sig, True, local_cache)
+        assert local_cache.aggregate_verify([pk], [msg], sig, True)
     # Verify the same messages with aggregated signature (full cache hit)
-    assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, False, local_cache)
+    assert local_cache.aggregate_verify(pks_half, msgs_half, agg_sig_half, False)
     # Verify more messages (partial cache hit)
-    assert cached_bls.aggregate_verify(pks, msgs, agg_sig, False, local_cache)
+    assert local_cache.aggregate_verify(pks, msgs, agg_sig, False)
 
 
 def test_cached_bls_repeat_pk():
@@ -54,4 +55,4 @@ def test_cached_bls_repeat_pk():
 
     assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)
 
-    assert cached_bls.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
+    assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
diff --git a/chia/_tests/farmer_harvester/test_third_party_harvesters.py b/chia/_tests/farmer_harvester/test_third_party_harvesters.py
index c376c14e0738..956bd15048dc 100644
--- a/chia/_tests/farmer_harvester/test_third_party_harvesters.py
+++ b/chia/_tests/farmer_harvester/test_third_party_harvesters.py
@@ -427,7 +427,7 @@ async def add_test_blocks_into_full_node(blocks: List[FullBlock], full_node: Ful
     )
     assert pre_validation_results is not None and len(pre_validation_results) == len(blocks)
     for i in range(len(blocks)):
-        r, _, _ = await full_node.blockchain.add_block(blocks[i], pre_validation_results[i])
+        r, _, _ = await full_node.blockchain.add_block(blocks[i], pre_validation_results[i], None)
         assert r == AddBlockResult.NEW_PEAK
diff --git a/chia/_tests/util/full_sync.py b/chia/_tests/util/full_sync.py
index 843db4a57bb9..f0d4c0e66dbd 100644
--- a/chia/_tests/util/full_sync.py
+++ b/chia/_tests/util/full_sync.py
@@ -197,7 +197,7 @@ async def run_sync_test(
                     if keep_up:
                         for b in block_batch:
                             await full_node.add_unfinished_block(make_unfinished_block(b, constants), peer)
-                            await full_node.add_block(b)
+                            await full_node.add_block(b, None, full_node._bls_cache)
                     else:
                         success, summary, _ = await full_node.add_block_batch(block_batch, peer_info, None)
                         end_height = block_batch[-1].height
diff --git a/chia/consensus/block_body_validation.py b/chia/consensus/block_body_validation.py
index 03257b5a6e64..6ff2905a0ef6 100644
--- a/chia/consensus/block_body_validation.py
+++ b/chia/consensus/block_body_validation.py
@@ -5,7 +5,7 @@
 from dataclasses import dataclass, field
 from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union
 
-from chia_rs import G1Element
+from chia_rs import AugSchemeMPL, G1Element
 from chiabip158 import PyBIP158
 
 from chia.consensus.block_record import BlockRecord
@@ -25,7 +25,7 @@
 from chia.types.full_block import FullBlock
 from chia.types.generator_types import BlockGenerator
 from chia.types.unfinished_block import UnfinishedBlock
-from chia.util import cached_bls
+from chia.util.cached_bls import BLSCache
 from chia.util.condition_tools import pkm_pairs
 from chia.util.errors import Err
 from chia.util.hash import std_hash
@@ -136,6 +136,7 @@ async def validate_block_body(
     npc_result: Optional[NPCResult],
     fork_info: ForkInfo,
     get_block_generator: Callable[[BlockInfo], Awaitable[Optional[BlockGenerator]]],
+    bls_cache: Optional[BLSCache],
     *,
     validate_signature: bool = True,
 ) -> Tuple[Optional[Err], Optional[NPCResult]]:
@@ -530,10 +531,11 @@ async def validate_block_body(
     # as the cache is likely to be useful when validating the corresponding
     # finished blocks later.
     if validate_signature:
-        force_cache: bool = isinstance(block, UnfinishedBlock)
-        if not cached_bls.aggregate_verify(
-            pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature, force_cache
-        ):
-            return Err.BAD_AGGREGATE_SIGNATURE, None
+        if bls_cache is None:
+            if not AugSchemeMPL.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
+                return Err.BAD_AGGREGATE_SIGNATURE, None
+        else:
+            if not bls_cache.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
+                return Err.BAD_AGGREGATE_SIGNATURE, None
 
     return None, npc_result
diff --git a/chia/consensus/blockchain.py b/chia/consensus/blockchain.py
index 5891d81b2cb7..7a75f1f048f6 100644
--- a/chia/consensus/blockchain.py
+++ b/chia/consensus/blockchain.py
@@ -45,6 +45,7 @@
 from chia.types.unfinished_block import UnfinishedBlock
 from chia.types.unfinished_header_block import UnfinishedHeaderBlock
 from chia.types.weight_proof import SubEpochChallengeSegment
+from chia.util.cached_bls import BLSCache
 from chia.util.errors import ConsensusError, Err
 from chia.util.generator_tools import get_block_header
 from chia.util.hash import std_hash
@@ -290,6 +291,7 @@ async def add_block(
         self,
         block: FullBlock,
         pre_validation_result: PreValidationResult,
+        bls_cache: Optional[BLSCache],
         fork_info: Optional[ForkInfo] = None,
     ) -> Tuple[AddBlockResult, Optional[Err], Optional[StateChangeSummary]]:
         """
@@ -302,6 +304,9 @@ async def add_block(
         Args:
             block: The FullBlock to be validated.
             pre_validation_result: A result of successful pre validation
+            bls_cache: An optional cache of pairings that are likely to be part
+                of the aggregate signature. If this is set, the cache will always
+                be used (which may be slower if there are no cache hits).
             fork_info: Information about the fork chain this block is part of,
                 to make validation more efficient. This is an in-out parameter.
 
@@ -430,6 +435,7 @@ async def add_block(
             npc_result,
             fork_info,
             self.get_block_generator,
+            bls_cache,
             # If we did not already validate the signature, validate it now
             validate_signature=not pre_validation_result.validated_signature,
         )
@@ -778,6 +784,7 @@ async def validate_unfinished_block(
             npc_result,
             fork_info,
             self.get_block_generator,
+            None,
             validate_signature=False,  # Signature was already validated before calling this method, no need to validate
         )
diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py
index e6a5ffeaa33d..c5e178f069ed 100644
--- a/chia/full_node/full_node.py
+++ b/chia/full_node/full_node.py
@@ -82,8 +82,8 @@
 from chia.types.transaction_queue_entry import TransactionQueueEntry
 from chia.types.unfinished_block import UnfinishedBlock
 from chia.types.weight_proof import WeightProof
-from chia.util import cached_bls
 from chia.util.bech32m import encode_puzzle_hash
+from chia.util.cached_bls import BLSCache
 from chia.util.check_fork_next_block import check_fork_next_block
 from chia.util.condition_tools import pkm_pairs
 from chia.util.config import process_config_start_method
@@ -166,6 +166,7 @@ class FullNode:
     # hashes of peaks that failed long sync on chip13 Validation
     bad_peak_cache: Dict[bytes32, uint32] = dataclasses.field(default_factory=dict)
     wallet_sync_task: Optional[asyncio.Task[None]] = None
+    _bls_cache: BLSCache = dataclasses.field(default_factory=lambda: BLSCache(50000))
 
     @property
     def server(self) -> ChiaServer:
@@ -667,6 +668,8 @@ async def short_sync_backtrack(
                     curr_height -= 1
                 if found_fork_point:
                     for block in reversed(blocks):
+                        # when syncing, we won't share any signatures with the
+                        # mempool, so there's no need to pass in the BLS cache.
                         await self.add_block(block, peer)
         except (asyncio.CancelledError, Exception):
             self.sync_store.decrement_backtrack_syncing(node_id=peer.peer_node_id)
@@ -1293,8 +1296,11 @@ async def add_block_batch(
             for i, block in enumerate(blocks_to_validate):
                 assert pre_validation_results[i].required_iters is not None
                 state_change_summary: Optional[StateChangeSummary]
+                # when adding blocks in batches, we won't have any overlapping
+                # signatures with the mempool. There won't be any cache hits, so
+                # there's no need to pass the BLS cache in
                 result, error, state_change_summary = await self.blockchain.add_block(
-                    block, pre_validation_results[i], fork_info
+                    block, pre_validation_results[i], None, fork_info
                 )
 
                 if result == AddBlockResult.NEW_PEAK:
@@ -1639,6 +1645,7 @@ async def add_block(
         self,
         block: FullBlock,
         peer: Optional[WSChiaConnection] = None,
+        bls_cache: Optional[BLSCache] = None,
         raise_on_disconnected: bool = False,
         fork_info: Optional[ForkInfo] = None,
     ) -> Optional[Message]:
@@ -1713,7 +1720,7 @@ async def add_block(
                     f"same farmer with the same pospace."
                 )
                 # This recursion ends here, we cannot recurse again because transactions_generator is not None
-                return await self.add_block(new_block, peer)
+                return await self.add_block(new_block, peer, bls_cache)
         state_change_summary: Optional[StateChangeSummary] = None
         ppp_result: Optional[PeakPostProcessingResult] = None
         async with self.blockchain.priority_mutex.acquire(priority=BlockchainMutexPriority.high), enable_profiler(
@@ -1755,7 +1762,7 @@ async def add_block(
                 )
                 assert result_to_validate.required_iters == pre_validation_results[0].required_iters
                 (added, error_code, state_change_summary) = await self.blockchain.add_block(
-                    block, result_to_validate, fork_info
+                    block, result_to_validate, bls_cache, fork_info
                 )
             if added == AddBlockResult.ALREADY_HAVE_BLOCK:
                 return None
@@ -1971,7 +1978,7 @@ async def add_unfinished_block(
             # guaranteed to represent a successful run
             assert npc_result.conds is not None
             pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, self.constants.AGG_SIG_ME_ADDITIONAL_DATA)
-            if not cached_bls.aggregate_verify(
+            if not self._bls_cache.aggregate_verify(
                 pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature, True
             ):
                 raise ConsensusError(Err.BAD_AGGREGATE_SIGNATURE)
@@ -2201,7 +2208,7 @@ async def new_infusion_point_vdf(
                 self.log.warning("Trying to make a pre-farm block but height is not 0")
                 return None
             try:
-                await self.add_block(block, raise_on_disconnected=True)
+                await self.add_block(block, None, self._bls_cache, raise_on_disconnected=True)
             except Exception as e:
                 self.log.warning(f"Consensus error validating block: {e}")
                 if timelord_peer is not None:
@@ -2334,7 +2341,9 @@ async def add_transaction(
             self.mempool_manager.remove_seen(spend_name)
         else:
             try:
-                cost_result = await self.mempool_manager.pre_validate_spendbundle(transaction, tx_bytes, spend_name)
+                cost_result = await self.mempool_manager.pre_validate_spendbundle(
+                    transaction, tx_bytes, spend_name, self._bls_cache
+                )
             except ValidationError as e:
                 self.mempool_manager.remove_seen(spend_name)
                 return MempoolInclusionStatus.FAILED, e.code
diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py
index 45348d6abc5c..79569bd70088 100644
--- a/chia/full_node/mempool_manager.py
+++ b/chia/full_node/mempool_manager.py
@@ -9,7 +9,7 @@
 from multiprocessing.context import BaseContext
 from typing import Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple, TypeVar
 
-from chia_rs import ELIGIBLE_FOR_DEDUP, ELIGIBLE_FOR_FF, G1Element, GTElement, supports_fast_forward
+from chia_rs import ELIGIBLE_FOR_DEDUP, ELIGIBLE_FOR_FF, G1Element, supports_fast_forward
 from chiabip158 import PyBIP158
 
 from chia.consensus.block_record import BlockRecordProtocol
@@ -32,14 +32,12 @@
 from chia.types.mempool_item import BundleCoinSpend, MempoolItem
 from chia.types.spend_bundle import SpendBundle
 from chia.types.spend_bundle_conditions import SpendBundleConditions
-from chia.util import cached_bls
-from chia.util.cached_bls import LOCAL_CACHE
+from chia.util.cached_bls import BLSCache
 from chia.util.condition_tools import pkm_pairs
 from chia.util.db_wrapper import SQLITE_INT_MAX
 from chia.util.errors import Err, ValidationError
 from chia.util.inline_executor import InlineExecutor
 from chia.util.ints import uint32, uint64
-from chia.util.lru_cache import LRUCache
 from chia.util.setproctitle import getproctitle, setproctitle
 
 log = logging.getLogger(__name__)
@@ -53,7 +51,7 @@
 # the constants through here
 def validate_clvm_and_signature(
     spend_bundle_bytes: bytes, max_cost: int, constants: ConsensusConstants, height: uint32
-) -> Tuple[Optional[Err], bytes, Dict[bytes32, bytes], float]:
+) -> Tuple[Optional[Err], bytes, List[Tuple[bytes32, bytes]], float]:
     """
     Validates CLVM and aggregate signature for a spendbundle. This is meant to be called under a ProcessPoolExecutor
    in order to validate the heavy parts of a transaction in a different thread. Returns an optional error,
@@ -72,7 +70,7 @@ def validate_clvm_and_signature(
     )
 
     if result.error is not None:
-        return Err(result.error), b"", {}, time.monotonic() - start_time
+        return Err(result.error), b"", [], time.monotonic() - start_time
 
     pks: List[G1Element] = []
     msgs: List[bytes] = []
@@ -80,16 +78,14 @@ def validate_clvm_and_signature(
         pks, msgs = pkm_pairs(result.conds, additional_data)
 
         # Verify aggregated signature
-        cache: LRUCache[bytes32, GTElement] = LRUCache(10000)
-        if not cached_bls.aggregate_verify(pks, msgs, bundle.aggregated_signature, True, cache):
-            return Err.BAD_AGGREGATE_SIGNATURE, b"", {}, time.monotonic() - start_time
-        new_cache_entries: Dict[bytes32, bytes] = {}
-        for k, v in cache.cache.items():
-            new_cache_entries[k] = bytes(v)
+        cache = BLSCache(10000)
+        if not cache.aggregate_verify(pks, msgs, bundle.aggregated_signature, True):
+            return Err.BAD_AGGREGATE_SIGNATURE, b"", [], time.monotonic() - start_time
+        new_cache_entries: List[Tuple[bytes32, bytes]] = cache.items()
     except ValidationError as e:
-        return e.code, b"", {}, time.monotonic() - start_time
+        return e.code, b"", [], time.monotonic() - start_time
     except Exception:
-        return Err.UNKNOWN, b"", {}, time.monotonic() - start_time
+        return Err.UNKNOWN, b"", [], time.monotonic() - start_time
 
     return None, bytes(result), new_cache_entries, time.monotonic() - start_time
@@ -288,7 +284,11 @@ def remove_seen(self, bundle_hash: bytes32) -> None:
             self.seen_bundle_hashes.pop(bundle_hash)
 
     async def pre_validate_spendbundle(
-        self, new_spend: SpendBundle, new_spend_bytes: Optional[bytes], spend_name: bytes32
+        self,
+        new_spend: SpendBundle,
+        new_spend_bytes: Optional[bytes],
+        spend_name: bytes32,
+        bls_cache: Optional[BLSCache] = None,
     ) -> NPCResult:
         """
         Errors are included within the cached_result.
@@ -317,8 +317,9 @@ async def pre_validate_spendbundle(
         if err is not None:
             raise ValidationError(err)
 
-        for cache_entry_key, cached_entry_value in new_cache_entries.items():
-            LOCAL_CACHE.put(cache_entry_key, GTElement.from_bytes_unchecked(cached_entry_value))
+        if bls_cache is not None:
+            bls_cache.update(new_cache_entries)
+
         ret: NPCResult = NPCResult.from_bytes(cached_result_bytes)
         log.log(
             logging.DEBUG if duration < 2 else logging.WARNING,
diff --git a/chia/simulator/full_node_simulator.py b/chia/simulator/full_node_simulator.py
index 29560e4d70d0..207ed7af3c07 100644
--- a/chia/simulator/full_node_simulator.py
+++ b/chia/simulator/full_node_simulator.py
@@ -174,7 +174,7 @@ async def farm_new_transaction_block(
                 )
             )
             assert pre_validation_results is not None
-            await self.full_node.blockchain.add_block(genesis, pre_validation_results[0])
+            await self.full_node.blockchain.add_block(genesis, pre_validation_results[0], self.full_node._bls_cache)
             peak = self.full_node.blockchain.get_peak()
             assert peak is not None
@@ -225,7 +225,7 @@ async def farm_new_block(self, request: FarmNewBlockProtocol, force_wait_for_tim
                 )
             )
             assert pre_validation_results is not None
-            await self.full_node.blockchain.add_block(genesis, pre_validation_results[0])
+            await self.full_node.blockchain.add_block(genesis, pre_validation_results[0], self.full_node._bls_cache)
             peak = self.full_node.blockchain.get_peak()
             assert peak is not None
diff --git a/chia/util/cached_bls.py b/chia/util/cached_bls.py
index 1eba5c248f48..532f4977a09c 100644
--- a/chia/util/cached_bls.py
+++ b/chia/util/cached_bls.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Tuple
 
 from chia_rs import AugSchemeMPL, G1Element, G2Element, GTElement
 
@@ -10,54 +10,62 @@
 from chia.util.lru_cache import LRUCache
 
 
-def get_pairings(
-    cache: LRUCache[bytes32, GTElement], pks: List[G1Element], msgs: Sequence[bytes], force_cache: bool
-) -> List[GTElement]:
-    pairings: List[Optional[GTElement]] = []
-    missing_count: int = 0
-    for pk, msg in zip(pks, msgs):
-        aug_msg: bytes = bytes(pk) + msg
-        h: bytes32 = std_hash(aug_msg)
-        pairing: Optional[GTElement] = cache.get(h)
-        if not force_cache and pairing is None:
-            missing_count += 1
-            # Heuristic to avoid more expensive sig validation with pairing
-            # cache when it's empty and cached pairings won't be useful later
-            # (e.g. while syncing)
-            if missing_count > len(pks) // 2:
-                return []
-        pairings.append(pairing)
-
-    ret: List[GTElement] = []
-    for i, pairing in enumerate(pairings):
-        if pairing is None:
-            aug_msg = bytes(pks[i]) + msgs[i]
-            aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
-            pairing = aug_hash.pair(pks[i])
-            h = std_hash(aug_msg)
-            cache.put(h, pairing)
-            ret.append(pairing)
-        else:
-            ret.append(pairing)
-    return ret
-
-
-# Increasing this number will increase RAM usage, but decrease BLS validation time for blocks and unfinished blocks.
-LOCAL_CACHE: LRUCache[bytes32, GTElement] = LRUCache(50000)
-
-
-def aggregate_verify(
-    pks: List[G1Element],
-    msgs: Sequence[bytes],
-    sig: G2Element,
-    force_cache: bool = False,
-    cache: LRUCache[bytes32, GTElement] = LOCAL_CACHE,
-) -> bool:
-    pairings: List[GTElement] = get_pairings(cache, pks, msgs, force_cache)
-    if len(pairings) == 0:
-        # Using AugSchemeMPL.aggregate_verify, so it's safe to use from_bytes_unchecked
-        return AugSchemeMPL.aggregate_verify(pks, msgs, sig)
-
-    pairings_prod: GTElement = functools.reduce(GTElement.__mul__, pairings)
-    res = pairings_prod == sig.pair(G1Element.generator())
-    return res
+class BLSCache:
+    cache: LRUCache[bytes32, GTElement]
+
+    def __init__(self, size: int = 50000):
+        self.cache = LRUCache(size)
+
+    def get_pairings(self, pks: List[G1Element], msgs: Sequence[bytes], force_cache: bool) -> List[GTElement]:
+        pairings: List[Optional[GTElement]] = []
+        missing_count: int = 0
+        for pk, msg in zip(pks, msgs):
+            aug_msg: bytes = bytes(pk) + msg
+            h: bytes32 = std_hash(aug_msg)
+            pairing: Optional[GTElement] = self.cache.get(h)
+            if not force_cache and pairing is None:
+                missing_count += 1
+                # Heuristic to avoid more expensive sig validation with pairing
+                # cache when it's empty and cached pairings won't be useful later
+                # (e.g. while syncing)
+                if missing_count > len(pks) // 2:
+                    return []
+            pairings.append(pairing)
+
+        # computing a pairing is expensive, so cache the result to avoid
+        # recomputing it for the same (pk, msg) pair later
+        ret: List[GTElement] = []
+        for i, pairing in enumerate(pairings):
+            if pairing is None:
+                aug_msg = bytes(pks[i]) + msgs[i]
+                aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
+                pairing = aug_hash.pair(pks[i])
+
+                h = std_hash(aug_msg)
+                self.cache.put(h, pairing)
+                ret.append(pairing)
+            else:
+                ret.append(pairing)
+        return ret
+
+    def aggregate_verify(
+        self,
+        pks: List[G1Element],
+        msgs: Sequence[bytes],
+        sig: G2Element,
+        force_cache: bool = False,
+    ) -> bool:
+        pairings: List[GTElement] = self.get_pairings(pks, msgs, force_cache)
+        if len(pairings) == 0:
+            res: bool = AugSchemeMPL.aggregate_verify(pks, msgs, sig)
+            return res
+
+        pairings_prod: GTElement = functools.reduce(GTElement.__mul__, pairings)
+        res = pairings_prod == sig.pair(G1Element.generator())
+        return res
+
+    def update(self, other: List[Tuple[bytes32, bytes]]) -> None:
+        for key, value in other:
+            self.cache.put(key, GTElement.from_bytes_unchecked(value))
+
+    def items(self) -> List[Tuple[bytes32, bytes]]:
+        return [(key, value.to_bytes()) for key, value in self.cache.cache.items()]
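
Below is a minimal usage sketch of the BLSCache API introduced in chia/util/cached_bls.py above. It is illustrative only, not part of the patch; it assumes the chia_rs AugSchemeMPL key and signature helpers, while the BLSCache constructor and aggregate_verify calls mirror the code added by this diff.

    from chia_rs import AugSchemeMPL

    from chia.util.cached_bls import BLSCache

    # two key pairs, two messages, one aggregate signature
    sks = [AugSchemeMPL.key_gen(bytes([i + 1] * 32)) for i in range(2)]
    pks = [sk.get_g1() for sk in sks]
    msgs = [b"msg-0", b"msg-1"]
    agg_sig = AugSchemeMPL.aggregate([AugSchemeMPL.sign(sk, msg) for sk, msg in zip(sks, msgs)])

    cache = BLSCache(100)  # LRU capacity of 100 pairings
    # force_cache=True computes and stores the pairings even on a cold cache
    assert cache.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
    # the second verification is served from the cached pairings
    assert cache.aggregate_verify(pks, msgs, agg_sig)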
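
The items()/update() pair added to BLSCache supports the cross-process flow in the mempool_manager changes: a worker process verifies a spend bundle with a throwaway cache and returns the computed pairings as serialized bytes, and the main process merges them into its long-lived cache. A self-contained sketch under the same assumptions as above; worker_verify is a hypothetical stand-in for validate_clvm_and_signature.

    from typing import List, Tuple

    from chia_rs import AugSchemeMPL, G1Element, G2Element

    from chia.types.blockchain_format.sized_bytes import bytes32
    from chia.util.cached_bls import BLSCache


    def worker_verify(
        pks: List[G1Element], msgs: List[bytes], sig: G2Element
    ) -> Tuple[bool, List[Tuple[bytes32, bytes]]]:
        # mirrors validate_clvm_and_signature: verify with a throwaway cache and
        # return the computed pairings as serialized (hash, GTElement bytes) pairs
        cache = BLSCache(10000)
        ok = cache.aggregate_verify(pks, msgs, sig, True)
        return ok, cache.items() if ok else []


    sk = AugSchemeMPL.key_gen(b"x" * 32)
    pk, msg = sk.get_g1(), b"hello"
    sig = AugSchemeMPL.sign(sk, msg)

    node_cache = BLSCache(50000)
    ok, entries = worker_verify([pk], [msg], sig)
    assert ok
    # merge the worker's entries into the long-lived cache, as
    # pre_validate_spendbundle now does when a bls_cache is passed in
    node_cache.update(entries)
    # a later verification of the same pairing hits the warm cache
    assert node_cache.aggregate_verify([pk], [msg], sig)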