Skip to content

Commit

Permalink
fix(pool): only update reputation if originally accepted by pool
Browse files Browse the repository at this point in the history
  • Loading branch information
dancoombs committed Sep 5, 2023
1 parent 38e6ced commit da9b775
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 34 deletions.
5 changes: 3 additions & 2 deletions src/op_pool/mempool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ pub trait Mempool: Send + Sync {
/// Returns the best operations from the pool.
///
/// Returns the best operations from the pool based on their gas bids up to
/// the specified maximum number of operations. Will only return one operation
/// per sender.
/// the specified maximum number of operations.
///
/// NOTE: Will only return one operation per sender.
fn best_operations(&self, max: usize) -> Vec<Arc<PoolOperation>>;

/// Returns the all operations from the pool up to a max size
Expand Down
10 changes: 0 additions & 10 deletions src/op_pool/mempool/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,6 @@ impl PoolInner {
ret
}

pub fn add_operations(
&mut self,
operations: impl IntoIterator<Item = PoolOperation>,
) -> Vec<MempoolResult<H256>> {
operations
.into_iter()
.map(|op| self.add_operation(op))
.collect()
}

pub fn best_operations(&self) -> impl Iterator<Item = Arc<PoolOperation>> {
self.best.clone().into_iter().map(|v| v.po)
}
Expand Down
260 changes: 238 additions & 22 deletions src/op_pool/mempool/uo_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
};

use ethers::types::{Address, H256};
use itertools::Itertools;
use parking_lot::RwLock;
use tokio::sync::broadcast;
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -118,8 +119,9 @@ where
.pool
.mine_operation(op.hash, update.latest_block_number)
{
for entity in op.staked_entities() {
self.reputation.add_included(entity.address);
// Only account for a staked entity once
for entity_addr in op.staked_entities().map(|e| e.address).unique() {
self.reputation.add_included(entity_addr);
}
mined_op_count += 1;
}
Expand All @@ -130,8 +132,9 @@ where
}

if let Some(op) = state.pool.unmine_operation(op.hash) {
for entity in op.staked_entities() {
self.reputation.remove_included(entity.address);
// Only account for a staked entity once
for entity_addr in op.staked_entities().map(|e| e.address).unique() {
self.reputation.add_included(entity_addr);
}
unmined_op_count += 1;
}
Expand Down Expand Up @@ -188,6 +191,11 @@ where
let mut throttled = false;
let mut rejected_entity: Option<Entity> = None;
let mut entity_summary = EntitySummary::default();
let mut staked_entities = HashSet::new();

// Check reputation of entities in involved in the operation
// If throttled, entity can have 1 inflight operation at a time, else reject
// If banned, reject
for entity in op.entities() {
let address = entity.address;
let reputation = match self.reputation.status(address) {
Expand All @@ -206,10 +214,14 @@ where
EntityReputation::Banned
}
};
let needs_stake = op.is_staked(entity.kind);
if needs_stake {
self.reputation.add_seen(address);
}

let needs_stake = if op.is_staked(entity.kind) {
staked_entities.insert(entity.address);
true
} else {
false
};

entity_summary.set_status(
entity.kind,
EntityStatus {
Expand Down Expand Up @@ -239,11 +251,14 @@ where
}
};

// If throttled/banned, emit event and return error
if let Some(entity) = rejected_entity {
let error = MempoolError::EntityThrottled(entity);
emit_event(self.state.read().block_number, Some(error.to_string()));
return Err(MempoolError::EntityThrottled(entity));
}

// Else, add operation to pool
let mut state = self.state.write();
let bn = state.block_number;
let hash = match state.pool.add_operation(op) {
Expand All @@ -253,6 +268,12 @@ where
return Err(error);
}
};

// If successfully added, update reputation of each entity
staked_entities
.into_iter()
.for_each(|e| self.reputation.add_seen(e));
// If an entity was throttled, track with throttled ops
if throttled {
state.throttled_ops.insert(hash, bn);
}
Expand All @@ -262,10 +283,13 @@ where

fn add_operations(
&self,
_origin: OperationOrigin,
origin: OperationOrigin,
operations: impl IntoIterator<Item = PoolOperation>,
) -> Vec<MempoolResult<H256>> {
self.state.write().pool.add_operations(operations)
operations
.into_iter()
.map(|op| self.add_operation(origin, op))
.collect()
}

fn remove_operations<'a>(&self, hashes: impl IntoIterator<Item = &'a H256>) {
Expand Down Expand Up @@ -360,7 +384,13 @@ impl UoPoolMetrics {
#[cfg(test)]
mod tests {
use super::*;
use crate::{common::types::UserOperation, op_pool::chain::MinedOp};
use crate::{
common::types::{EntityType, UserOperation},
op_pool::chain::MinedOp,
};

const THROTTLE_SLACK: u64 = 5;
const BAN_SLACK: u64 = 10;

#[test]
fn add_single_op() {
Expand Down Expand Up @@ -501,6 +531,118 @@ mod tests {
check_ops(pool.best_operations(3), ops);
}

#[test]
fn test_account_reputation() {
let pool = create_pool();
let address = Address::random();
let ops = vec![
create_staked_account_op(address, 0, 2),
create_staked_account_op(address, 1, 2),
create_staked_account_op(address, 1, 2),
];
pool.add_operations(OperationOrigin::Local, ops.clone());
// Only return 1 op per sender
check_ops(pool.best_operations(3), vec![ops[0].clone()]);

let rep = pool.dump_reputation();
assert_eq!(rep.len(), 1);
assert_eq!(rep[0].address, address);
assert_eq!(rep[0].ops_seen, 2); // 2 ops seen, 1 rejected at insert
assert_eq!(rep[0].ops_included, 0); // No ops included yet

pool.on_chain_update(&ChainUpdate {
latest_block_number: 1,
latest_block_hash: H256::random(),
earliest_remembered_block_number: 0,
reorg_depth: 0,
mined_ops: vec![MinedOp {
entry_point: pool.entry_point,
hash: ops[0].uo.op_hash(pool.entry_point, 1),
sender: ops[0].uo.sender,
nonce: ops[0].uo.nonce,
}],
unmined_ops: vec![],
});

let rep = pool.dump_reputation();
assert_eq!(rep.len(), 1);
assert_eq!(rep[0].address, address);
assert_eq!(rep[0].ops_seen, 2); // 2 ops seen, 1 rejected at insert
assert_eq!(rep[0].ops_included, 1); // 1 op included
}

#[test]
fn test_throttled_account() {
let pool = create_pool();
let address = Address::random();

// Past throttle slack
pool.set_reputation(address, 1 + THROTTLE_SLACK, 0);

let ops = vec![
create_staked_account_op(address, 0, 2),
create_staked_account_op(address, 1, 2),
];

// First op should be included
pool.add_operation(OperationOrigin::Local, ops[0].clone())
.unwrap();
check_ops(pool.best_operations(1), vec![ops[0].clone()]);

// Second op should be thorottled
let ret = pool.add_operation(OperationOrigin::Local, ops[1].clone());
assert!(ret.is_err());
match ret.unwrap_err() {
MempoolError::EntityThrottled(entity) => {
assert_eq!(entity.address, address);
assert_eq!(entity.kind, EntityType::Account)
}
_ => panic!("Expected throttled error"),
}

// Mine first op
pool.on_chain_update(&ChainUpdate {
latest_block_number: 1,
latest_block_hash: H256::random(),
earliest_remembered_block_number: 0,
reorg_depth: 0,
mined_ops: vec![MinedOp {
entry_point: pool.entry_point,
hash: ops[0].uo.op_hash(pool.entry_point, 1),
sender: ops[0].uo.sender,
nonce: ops[0].uo.nonce,
}],
unmined_ops: vec![],
});

// Second op should be included
pool.add_operation(OperationOrigin::Local, ops[1].clone())
.unwrap();
check_ops(pool.best_operations(1), vec![ops[1].clone()]);
}

#[test]
fn test_banned_account() {
let pool = create_pool();
let address = Address::random();

// Past ban slack
pool.set_reputation(address, 1 + BAN_SLACK, 0);

let op = create_staked_account_op(address, 0, 2);

// First op should be banned
let ret = pool.add_operation(OperationOrigin::Local, op.clone());
assert!(ret.is_err());
match ret.unwrap_err() {
MempoolError::EntityThrottled(entity) => {
assert_eq!(entity.address, address);
assert_eq!(entity.kind, EntityType::Account)
}
_ => panic!("Expected throttled error"),
}
}

fn create_pool() -> UoPool<MockReputationManager> {
let args = PoolConfig {
entry_point: Address::random(),
Expand All @@ -512,7 +654,11 @@ mod tests {
allowlist: None,
};
let (event_sender, _) = broadcast::channel(4);
UoPool::new(args, mock_reputation(), event_sender)
UoPool::new(
args,
mock_reputation(THROTTLE_SLACK, BAN_SLACK),
event_sender,
)
}

fn create_op(sender: Address, nonce: usize, max_fee_per_gas: usize) -> PoolOperation {
Expand All @@ -527,35 +673,105 @@ mod tests {
}
}

fn create_staked_account_op(
sender: Address,
nonce: usize,
max_fee_per_gas: usize,
) -> PoolOperation {
PoolOperation {
uo: UserOperation {
sender,
nonce: nonce.into(),
max_fee_per_gas: max_fee_per_gas.into(),
..UserOperation::default()
},
account_is_staked: true,
..PoolOperation::default()
}
}

fn check_ops(ops: Vec<Arc<PoolOperation>>, expected: Vec<PoolOperation>) {
assert_eq!(ops.len(), expected.len());
for (actual, expected) in ops.into_iter().zip(expected) {
assert_eq!(actual.uo, expected.uo);
}
}

fn mock_reputation() -> Arc<MockReputationManager> {
Arc::new(MockReputationManager {})
fn mock_reputation(throttling_slack: u64, ban_slack: u64) -> Arc<MockReputationManager> {
Arc::new(MockReputationManager::new(throttling_slack, ban_slack))
}

#[derive(Default, Clone)]
struct MockReputationManager;
struct MockReputationManager {
throttling_slack: u64,
ban_slack: u64,
counts: Arc<RwLock<Counts>>,
}

#[derive(Default)]
struct Counts {
seen: HashMap<Address, u64>,
included: HashMap<Address, u64>,
}

impl MockReputationManager {
fn new(throttling_slack: u64, ban_slack: u64) -> Self {
Self {
throttling_slack,
ban_slack,
..Self::default()
}
}
}

impl ReputationManager for MockReputationManager {
fn status(&self, _address: Address) -> ReputationStatus {
ReputationStatus::Ok
fn status(&self, address: Address) -> ReputationStatus {
let counts = self.counts.read();

let seen = *counts.seen.get(&address).unwrap_or(&0);
let included = *counts.included.get(&address).unwrap_or(&0);
let diff = seen.saturating_sub(included);
if diff > self.ban_slack {
ReputationStatus::Banned
} else if diff > self.throttling_slack {
ReputationStatus::Throttled
} else {
ReputationStatus::Ok
}
}

fn add_seen(&self, _address: Address) {}
fn add_seen(&self, address: Address) {
*self.counts.write().seen.entry(address).or_default() += 1;
}

fn add_included(&self, _address: Address) {}
fn add_included(&self, address: Address) {
*self.counts.write().included.entry(address).or_default() += 1;
}

fn remove_included(&self, _address: Address) {}
fn remove_included(&self, address: Address) {
let mut counts = self.counts.write();
let included = counts.included.entry(address).or_default();
*included = included.saturating_sub(1);
}

fn dump_reputation(&self) -> Vec<Reputation> {
vec![]
self.counts
.read()
.seen
.iter()
.map(|(address, ops_seen)| Reputation {
address: *address,
ops_seen: *ops_seen,
ops_included: *self.counts.read().included.get(address).unwrap_or(&0),
status: self.status(*address),
})
.collect()
}

fn set_reputation(&self, _address: Address, _ops_seen: u64, _ops_included: u64) {}
fn set_reputation(&self, address: Address, ops_seen: u64, ops_included: u64) {
let mut counts = self.counts.write();
counts.seen.insert(address, ops_seen);
counts.included.insert(address, ops_included);
}
}
}

0 comments on commit da9b775

Please sign in to comment.