Skip to content

Commit

Permalink
[CHIA-615] make BLSCache a proper class (#18053)
Browse files Browse the repository at this point in the history
* make BlsCache a proper class, rather than a global dict with helper functions

* make the BLS cache a member of the FullNode, rather than a global variable

* add test for BAD_AGGREGATE_SIGNATURE when using the BLS cache
  • Loading branch information
arvidn authored May 31, 2024
1 parent a83b240 commit 605e3b8
Show file tree
Hide file tree
Showing 13 changed files with 157 additions and 114 deletions.
13 changes: 11 additions & 2 deletions chia/_tests/blockchain/blockchain_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from chia.consensus.blockchain import AddBlockResult, Blockchain
from chia.consensus.multiprocess_validation import PreValidationResult
from chia.types.full_block import FullBlock
from chia.util.cached_bls import BLSCache
from chia.util.errors import Err
from chia.util.ints import uint32, uint64

Expand Down Expand Up @@ -42,10 +43,12 @@ async def check_block_store_invariant(bc: Blockchain):
async def _validate_and_add_block(
blockchain: Blockchain,
block: FullBlock,
*,
expected_result: Optional[AddBlockResult] = None,
expected_error: Optional[Err] = None,
skip_prevalidation: bool = False,
fork_info: Optional[ForkInfo] = None,
use_bls_cache: bool = False,
) -> None:
# Tries to validate and add the block, and checks that there are no errors in the process and that the
# block is added to the peak.
Expand All @@ -58,7 +61,8 @@ async def _validate_and_add_block(
if skip_prevalidation:
results = PreValidationResult(None, uint64(1), None, False, uint32(0))
else:
# Do not change this, validate_signatures must be False
# validate_signatures must be False in order to trigger add_block() to
# validate the signature.
pre_validation_results: List[PreValidationResult] = await blockchain.pre_validate_blocks_multiprocessing(
[block], {}, validate_signatures=False
)
Expand All @@ -78,11 +82,16 @@ async def _validate_and_add_block(
await check_block_store_invariant(blockchain)
return None

if use_bls_cache:
bls_cache = BLSCache(100)
else:
bls_cache = None

(
result,
err,
_,
) = await blockchain.add_block(block, results, fork_info=fork_info)
) = await blockchain.add_block(block, results, bls_cache, fork_info=fork_info)
await check_block_store_invariant(blockchain)

if expected_error is None and expected_result != AddBlockResult.INVALID_BLOCK:
Expand Down
32 changes: 19 additions & 13 deletions chia/_tests/blockchain/test_blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,7 @@ async def test_pre_validation(
assert res[n].error is None
block = blocks_to_validate[n]
start_rb = time.time()
result, err, _ = await empty_blockchain.add_block(block, res[n])
result, err, _ = await empty_blockchain.add_block(block, res[n], None)
end_rb = time.time()
times_rb.append(end_rb - start_rb)
assert err is None
Expand Down Expand Up @@ -1934,7 +1934,7 @@ async def test_conditions(
)
# Ignore errors from pre-validation, we are testing block_body_validation
repl_preval_results = replace(pre_validation_results[0], error=None, required_iters=uint64(1))
code, err, state_change = await b.add_block(blocks[-1], repl_preval_results)
code, err, state_change = await b.add_block(blocks[-1], repl_preval_results, None)
assert code == AddBlockResult.NEW_PEAK
assert err is None
assert state_change is not None
Expand Down Expand Up @@ -2050,7 +2050,7 @@ async def test_timelock_conditions(
[blocks[-1]], {}, validate_signatures=True
)
assert pre_validation_results is not None
assert (await b.add_block(blocks[-1], pre_validation_results[0]))[0] == expected
assert (await b.add_block(blocks[-1], pre_validation_results[0], None))[0] == expected

if expected == AddBlockResult.NEW_PEAK:
# ensure coin was in fact spent
Expand Down Expand Up @@ -2152,7 +2152,7 @@ async def test_aggsig_garbage(
)
# Ignore errors from pre-validation, we are testing block_body_validation
repl_preval_results = replace(pre_validation_results[0], error=None, required_iters=uint64(1))
res, error, state_change = await b.add_block(blocks[-1], repl_preval_results)
res, error, state_change = await b.add_block(blocks[-1], repl_preval_results, None)
assert (res, error, state_change.fork_height if state_change else None) == expected

@pytest.mark.anyio
Expand Down Expand Up @@ -2268,7 +2268,7 @@ async def test_ephemeral_timelock(
[blocks[-1]], {}, validate_signatures=True
)
assert pre_validation_results is not None
assert (await b.add_block(blocks[-1], pre_validation_results[0]))[0] == expected
assert (await b.add_block(blocks[-1], pre_validation_results[0], None))[0] == expected

if expected == AddBlockResult.NEW_PEAK:
# ensure coin1 was in fact spent
Expand Down Expand Up @@ -2657,7 +2657,9 @@ async def test_cost_exceeds_max(
height=softfork_height,
constants=bt.constants,
)
err = (await b.add_block(blocks[-1], PreValidationResult(None, uint64(1), npc_result, True, uint32(0))))[1]
err = (await b.add_block(blocks[-1], PreValidationResult(None, uint64(1), npc_result, True, uint32(0)), None))[
1
]
assert err in [Err.BLOCK_COST_EXCEEDS_MAX]

results: List[PreValidationResult] = await b.pre_validate_blocks_multiprocessing(
Expand Down Expand Up @@ -2722,7 +2724,7 @@ async def test_invalid_cost_in_block(
height=softfork_height,
constants=bt.constants,
)
_, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
_, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None)
assert err == Err.INVALID_BLOCK_COST

# too low
Expand All @@ -2749,7 +2751,7 @@ async def test_invalid_cost_in_block(
height=softfork_height,
constants=bt.constants,
)
_, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
_, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None)
assert err == Err.INVALID_BLOCK_COST

# too high
Expand Down Expand Up @@ -2778,7 +2780,9 @@ async def test_invalid_cost_in_block(
block_generator, max_cost, mempool_mode=False, height=softfork_height, constants=bt.constants
)

result, err, _ = await b.add_block(block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)))
result, err, _ = await b.add_block(
block_2, PreValidationResult(None, uint64(1), npc_result, False, uint32(0)), None
)
assert err == Err.INVALID_BLOCK_COST

# when the CLVM program exceeds cost during execution, it will fail with
Expand Down Expand Up @@ -3220,6 +3224,8 @@ async def test_invalid_agg_sig(self, empty_blockchain: Blockchain, bt: BlockTool

# Bad signature fails during add_block
await _validate_and_add_block(b, last_block, expected_error=Err.BAD_AGGREGATE_SIGNATURE)
# Also test the same case but when using BLSCache
await _validate_and_add_block(b, last_block, expected_error=Err.BAD_AGGREGATE_SIGNATURE, use_bls_cache=True)

# Bad signature also fails in prevalidation
preval_results = await b.pre_validate_blocks_multiprocessing([last_block], {}, validate_signatures=True)
Expand Down Expand Up @@ -3336,7 +3342,7 @@ async def test_long_reorg(
assert pre_validation_results[i].error is None
if (block.height % 100) == 0:
print(f"main chain: {block.height:4} weight: {block.weight}")
(result, err, _) = await b.add_block(block, pre_validation_results[i])
(result, err, _) = await b.add_block(block, pre_validation_results[i], None)
await check_block_store_invariant(b)
assert err is None
assert result == AddBlockResult.NEW_PEAK
Expand Down Expand Up @@ -3874,10 +3880,10 @@ async def test_reorg_flip_flop(empty_blockchain: Blockchain, bt: BlockTools) ->
preval: List[PreValidationResult] = await b.pre_validate_blocks_multiprocessing(
[block1], {}, validate_signatures=False
)
_, err, _ = await b.add_block(block1, preval[0])
_, err, _ = await b.add_block(block1, preval[0], None)
assert err is None
preval = await b.pre_validate_blocks_multiprocessing([block2], {}, validate_signatures=False)
_, err, _ = await b.add_block(block2, preval[0])
_, err, _ = await b.add_block(block2, preval[0], None)
assert err is None

peak = b.get_peak()
Expand Down Expand Up @@ -3905,7 +3911,7 @@ async def test_get_tx_peak(default_400_blocks: List[FullBlock], empty_blockchain
last_tx_block_record = None
for b, prevalidation_res in zip(test_blocks, res):
assert bc.get_tx_peak() == last_tx_block_record
_, err, _ = await bc.add_block(b, prevalidation_res)
_, err, _ = await bc.add_block(b, prevalidation_res, None)
assert err is None

if b.is_transaction_block():
Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/core/test_db_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def test_blocks(default_1000_blocks, with_hints: bool):
for block in blocks:
# await _validate_and_add_block(bc, block)
results = PreValidationResult(None, uint64(1), None, False, uint32(0))
result, err, _ = await bc.add_block(block, results)
result, err, _ = await bc.add_block(block, results, None)
assert err is None

# now, convert v1 in_file to v2 out_file
Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/core/test_db_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def make_db(db_file: Path, blocks: List[FullBlock]) -> None:

for block in blocks:
results = PreValidationResult(None, uint64(1), None, False, uint32(0))
result, err, _ = await bc.add_block(block, results)
result, err, _ = await bc.add_block(block, results, None)
assert err is None


Expand Down
21 changes: 11 additions & 10 deletions chia/_tests/core/util/test_cached_bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from chia_rs import AugSchemeMPL

from chia.util import cached_bls
from chia.util.cached_bls import BLSCache
from chia.util.hash import std_hash
from chia.util.lru_cache import LRUCache

LOCAL_CACHE = BLSCache(50000)


def test_cached_bls():
Expand All @@ -25,21 +26,21 @@ def test_cached_bls():
assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)

# Verify with empty cache and populate it
assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
assert LOCAL_CACHE.aggregate_verify(pks_half, msgs_half, agg_sig_half, True)
# Verify with partial cache hit
assert cached_bls.aggregate_verify(pks, msgs, agg_sig, True)
assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig, True)
# Verify with full cache hit
assert cached_bls.aggregate_verify(pks, msgs, agg_sig)
assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig)

# Use a small cache which can not accommodate all pairings
local_cache = LRUCache(n_keys // 2)
local_cache = BLSCache(n_keys // 2)
# Verify signatures and cache pairings one at a time
for pk, msg, sig in zip(pks_half, msgs_half, sigs_half):
assert cached_bls.aggregate_verify([pk], [msg], sig, True, local_cache)
assert local_cache.aggregate_verify([pk], [msg], sig, True)
# Verify the same messages with aggregated signature (full cache hit)
assert cached_bls.aggregate_verify(pks_half, msgs_half, agg_sig_half, False, local_cache)
assert local_cache.aggregate_verify(pks_half, msgs_half, agg_sig_half, False)
# Verify more messages (partial cache hit)
assert cached_bls.aggregate_verify(pks, msgs, agg_sig, False, local_cache)
assert local_cache.aggregate_verify(pks, msgs, agg_sig, False)


def test_cached_bls_repeat_pk():
Expand All @@ -54,4 +55,4 @@ def test_cached_bls_repeat_pk():

assert AugSchemeMPL.aggregate_verify(pks, msgs, agg_sig)

assert cached_bls.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
assert LOCAL_CACHE.aggregate_verify(pks, msgs, agg_sig, force_cache=True)
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ async def add_test_blocks_into_full_node(blocks: List[FullBlock], full_node: Ful
)
assert pre_validation_results is not None and len(pre_validation_results) == len(blocks)
for i in range(len(blocks)):
r, _, _ = await full_node.blockchain.add_block(blocks[i], pre_validation_results[i])
r, _, _ = await full_node.blockchain.add_block(blocks[i], pre_validation_results[i], None)
assert r == AddBlockResult.NEW_PEAK


Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/util/full_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def run_sync_test(
if keep_up:
for b in block_batch:
await full_node.add_unfinished_block(make_unfinished_block(b, constants), peer)
await full_node.add_block(b)
await full_node.add_block(b, None, full_node._bls_cache)
else:
success, summary, _ = await full_node.add_block_batch(block_batch, peer_info, None)
end_height = block_batch[-1].height
Expand Down
16 changes: 9 additions & 7 deletions chia/consensus/block_body_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union

from chia_rs import G1Element
from chia_rs import AugSchemeMPL, G1Element
from chiabip158 import PyBIP158

from chia.consensus.block_record import BlockRecord
Expand All @@ -25,7 +25,7 @@
from chia.types.full_block import FullBlock
from chia.types.generator_types import BlockGenerator
from chia.types.unfinished_block import UnfinishedBlock
from chia.util import cached_bls
from chia.util.cached_bls import BLSCache
from chia.util.condition_tools import pkm_pairs
from chia.util.errors import Err
from chia.util.hash import std_hash
Expand Down Expand Up @@ -136,6 +136,7 @@ async def validate_block_body(
npc_result: Optional[NPCResult],
fork_info: ForkInfo,
get_block_generator: Callable[[BlockInfo], Awaitable[Optional[BlockGenerator]]],
bls_cache: Optional[BLSCache],
*,
validate_signature: bool = True,
) -> Tuple[Optional[Err], Optional[NPCResult]]:
Expand Down Expand Up @@ -530,10 +531,11 @@ async def validate_block_body(
# as the cache is likely to be useful when validating the corresponding
# finished blocks later.
if validate_signature:
force_cache: bool = isinstance(block, UnfinishedBlock)
if not cached_bls.aggregate_verify(
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature, force_cache
):
return Err.BAD_AGGREGATE_SIGNATURE, None
if bls_cache is None:
if not AugSchemeMPL.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
return Err.BAD_AGGREGATE_SIGNATURE, None
else:
if not bls_cache.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
return Err.BAD_AGGREGATE_SIGNATURE, None

return None, npc_result
7 changes: 7 additions & 0 deletions chia/consensus/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from chia.types.unfinished_block import UnfinishedBlock
from chia.types.unfinished_header_block import UnfinishedHeaderBlock
from chia.types.weight_proof import SubEpochChallengeSegment
from chia.util.cached_bls import BLSCache
from chia.util.errors import ConsensusError, Err
from chia.util.generator_tools import get_block_header
from chia.util.hash import std_hash
Expand Down Expand Up @@ -290,6 +291,7 @@ async def add_block(
self,
block: FullBlock,
pre_validation_result: PreValidationResult,
bls_cache: Optional[BLSCache],
fork_info: Optional[ForkInfo] = None,
) -> Tuple[AddBlockResult, Optional[Err], Optional[StateChangeSummary]]:
"""
Expand All @@ -302,6 +304,9 @@ async def add_block(
Args:
block: The FullBlock to be validated.
pre_validation_result: A result of successful pre validation
bls_cache: An optional cache of pairings that are likely to be part
of the aggregate signature. If this is set, the cache will always
be used (which may be slower if there are no cache hits).
fork_info: Information about the fork chain this block is part of,
to make validation more efficient. This is an in-out parameter.
Expand Down Expand Up @@ -430,6 +435,7 @@ async def add_block(
npc_result,
fork_info,
self.get_block_generator,
bls_cache,
# If we did not already validate the signature, validate it now
validate_signature=not pre_validation_result.validated_signature,
)
Expand Down Expand Up @@ -778,6 +784,7 @@ async def validate_unfinished_block(
npc_result,
fork_info,
self.get_block_generator,
None,
validate_signature=False, # Signature was already validated before calling this method, no need to validate
)

Expand Down
Loading

0 comments on commit 605e3b8

Please sign in to comment.