Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi: [db] simplify peer db, [node] add peer manager #6

Merged
merged 1 commit into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/db/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
use thiserror::Error;

#[derive(Error, Debug)]
pub enum HeaderDatabaseError {
#[error("loading a query or data from sqlite failed")]
pub enum DatabaseError {
#[error("loading a query or data from the database failed")]
LoadError,
#[error("writing a query or data from sqlite failed")]
#[error("writing a query or data from the database failed")]
WriteError,
}

#[derive(Error, Debug)]
pub enum PeerManagerError {
#[error("DNS failed to respond")]
Dns,
#[error("reading or writing from the database failed")]
Database(DatabaseError),
}
26 changes: 26 additions & 0 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
use std::net::IpAddr;

use bitcoin::p2p::ServiceFlags;

pub(crate) mod error;
pub(crate) mod peer_man;
pub(crate) mod sqlite;
pub(crate) mod traits;

#[derive(Debug, Clone)]
pub struct PersistedPeer {
pub addr: IpAddr,
pub port: u16,
pub services: ServiceFlags,
pub tried: bool,
pub banned: bool,
}

impl PersistedPeer {
pub fn new(addr: IpAddr, port: u16, services: ServiceFlags, tried: bool, banned: bool) -> Self {
Self {
addr,
port,
services,
tried,
banned,
}
}
}
131 changes: 131 additions & 0 deletions src/db/peer_man.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use std::{collections::HashSet, net::IpAddr, sync::Arc};

use bitcoin::{p2p::ServiceFlags, Network};
use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
use tokio::sync::Mutex;

use crate::{
peers::dns::Dns,
prelude::{default_port_from_network, SlashSixteen},
};

use super::{error::PeerManagerError, traits::PeerStore, PersistedPeer};

#[derive(Debug, Clone)]
pub(crate) struct PeerManager {
db: Arc<Mutex<dyn PeerStore + Send + Sync>>,
netgroups: HashSet<String>,
network: Network,
default_port: u16,
}

impl PeerManager {
pub(crate) fn new(db: impl PeerStore + Send + Sync + 'static, network: &Network) -> Self {
let default_port = default_port_from_network(network);
Self {
db: Arc::new(Mutex::new(db)),
netgroups: HashSet::new(),
network: *network,
default_port,
}
}

pub(crate) async fn next_peer(&mut self) -> Result<(IpAddr, u16), PeerManagerError> {
let mut db_lock = self.db.lock().await;
let mut tries = 0;
while tries < 10 {
let mut next = db_lock.random().await.map_err(PeerManagerError::Database)?;
if !self.netgroups.contains(&next.addr.slash_sixteen()) {
self.netgroups.insert(next.addr.slash_sixteen());
return Ok((next.addr, next.port));
}
tries += 1;
}
let mut next = db_lock.random().await.map_err(PeerManagerError::Database)?;
self.netgroups.insert(next.addr.slash_sixteen());
Ok((next.addr, next.port))
}

pub(crate) async fn bootstrap(&mut self) -> Result<(), PeerManagerError> {
let mut db_lock = self.db.lock().await;
let mut new_peers = Dns::bootstrap(self.network)
.await
.map_err(|_| PeerManagerError::Dns)?;
let mut rng = StdRng::from_entropy();
new_peers.shuffle(&mut rng);
// DNS fails if there is an insufficient number of peers
for peer in new_peers {
db_lock
.update(PersistedPeer::new(
peer,
self.default_port,
ServiceFlags::NONE,
false,
false,
))
.await
.map_err(PeerManagerError::Database)?;
}
Ok(())
}

pub(crate) async fn peer_count(&mut self) -> Result<u32, PeerManagerError> {
let mut db_lock = self.db.lock().await;
db_lock
.num_unbanned()
.await
.map_err(PeerManagerError::Database)
}

pub(crate) async fn add_new_peer(
&mut self,
addr: IpAddr,
port: Option<u16>,
services: Option<ServiceFlags>,
) -> Result<(), PeerManagerError> {
self.internal_db_update(addr, port, services, false, false)
.await
}

pub(crate) async fn tried_peer(
&mut self,
addr: IpAddr,
port: Option<u16>,
services: Option<ServiceFlags>,
) -> Result<(), PeerManagerError> {
self.internal_db_update(addr, port, services, true, false)
.await
}

pub(crate) async fn ban_peer(
&mut self,
addr: IpAddr,
port: Option<u16>,
services: Option<ServiceFlags>,
) -> Result<(), PeerManagerError> {
self.internal_db_update(addr, port, services, true, true)
.await
}

async fn internal_db_update(
&mut self,
addr: IpAddr,
port: Option<u16>,
services: Option<ServiceFlags>,
tried: bool,
ban: bool,
) -> Result<(), PeerManagerError> {
let mut db_lock = self.db.lock().await;
db_lock
.update(PersistedPeer::new(
addr,
port.unwrap_or(self.default_port),
services.unwrap_or(ServiceFlags::NONE),
tried,
ban,
))
.await
.map_err(PeerManagerError::Database)?;
Ok(())
}
}
61 changes: 28 additions & 33 deletions src/db/sqlite/header_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use bitcoin::{BlockHash, CompactTarget, Network, TxMerkleNode};
use rusqlite::{params, Connection, Result};
use tokio::sync::Mutex;

use crate::db::error::HeaderDatabaseError;
use crate::db::error::DatabaseError;
use crate::db::traits::HeaderStore;

const SCHEMA: &str = "CREATE TABLE IF NOT EXISTS headers (
Expand All @@ -31,17 +31,17 @@ pub(crate) struct SqliteHeaderDb {
}

impl SqliteHeaderDb {
pub fn new(network: Network, path: Option<PathBuf>) -> Result<Self, HeaderDatabaseError> {
pub fn new(network: Network, path: Option<PathBuf>) -> Result<Self, DatabaseError> {
let mut path = path.unwrap_or_else(|| PathBuf::from("."));
path.push("data");
path.push(network.to_string());
if !path.exists() {
fs::create_dir_all(&path).unwrap();
}
let conn = Connection::open(path.join("headers.db"))
.map_err(|_| HeaderDatabaseError::LoadError)?;
let conn =
Connection::open(path.join("headers.db")).map_err(|_| DatabaseError::LoadError)?;
conn.execute(SCHEMA, [])
.map_err(|_| HeaderDatabaseError::LoadError)?;
.map_err(|_| DatabaseError::LoadError)?;
Ok(Self {
network,
conn: Arc::new(Mutex::new(conn)),
Expand All @@ -52,32 +52,27 @@ impl SqliteHeaderDb {
#[async_trait]
impl HeaderStore for SqliteHeaderDb {
// load all the known headers from storage
async fn load(
&mut self,
anchor_height: u32,
) -> Result<BTreeMap<u32, Header>, HeaderDatabaseError> {
async fn load(&mut self, anchor_height: u32) -> Result<BTreeMap<u32, Header>, DatabaseError> {
let mut headers = BTreeMap::<u32, Header>::new();
let stmt = "SELECT * FROM headers ORDER BY height";
let write_lock = self.conn.lock().await;
let mut query = write_lock
.prepare(stmt)
.map_err(|_| HeaderDatabaseError::LoadError)?;
let mut rows = query
.query([])
.map_err(|_| HeaderDatabaseError::LoadError)?;
while let Some(row) = rows.next().map_err(|_| HeaderDatabaseError::LoadError)? {
let height: u32 = row.get(0).map_err(|_| HeaderDatabaseError::LoadError)?;
.map_err(|_| DatabaseError::LoadError)?;
let mut rows = query.query([]).map_err(|_| DatabaseError::LoadError)?;
while let Some(row) = rows.next().map_err(|_| DatabaseError::LoadError)? {
let height: u32 = row.get(0).map_err(|_| DatabaseError::LoadError)?;
// The anchor height should not be included in the chain, as the anchor is non-inclusive
if height.le(&anchor_height) {
continue;
}
let hash: String = row.get(1).map_err(|_| HeaderDatabaseError::LoadError)?;
let version: i32 = row.get(2).map_err(|_| HeaderDatabaseError::LoadError)?;
let prev_hash: String = row.get(3).map_err(|_| HeaderDatabaseError::LoadError)?;
let merkle_root: String = row.get(4).map_err(|_| HeaderDatabaseError::LoadError)?;
let time: u32 = row.get(5).map_err(|_| HeaderDatabaseError::LoadError)?;
let bits: u32 = row.get(6).map_err(|_| HeaderDatabaseError::LoadError)?;
let nonce: u32 = row.get(7).map_err(|_| HeaderDatabaseError::LoadError)?;
let hash: String = row.get(1).map_err(|_| DatabaseError::LoadError)?;
let version: i32 = row.get(2).map_err(|_| DatabaseError::LoadError)?;
let prev_hash: String = row.get(3).map_err(|_| DatabaseError::LoadError)?;
let merkle_root: String = row.get(4).map_err(|_| DatabaseError::LoadError)?;
let time: u32 = row.get(5).map_err(|_| DatabaseError::LoadError)?;
let bits: u32 = row.get(6).map_err(|_| DatabaseError::LoadError)?;
let nonce: u32 = row.get(7).map_err(|_| DatabaseError::LoadError)?;

let next_header = Header {
version: Version::from_consensus(version),
Expand Down Expand Up @@ -109,14 +104,14 @@ impl HeaderStore for SqliteHeaderDb {
async fn write<'a>(
&mut self,
header_chain: &'a BTreeMap<u32, Header>,
) -> Result<(), HeaderDatabaseError> {
) -> Result<(), DatabaseError> {
let mut write_lock = self.conn.lock().await;
let tx = write_lock
.transaction()
.map_err(|_| HeaderDatabaseError::WriteError)?;
.map_err(|_| DatabaseError::WriteError)?;
let best_height: Option<u32> = tx
.query_row("SELECT MAX(height) FROM headers", [], |row| row.get(0))
.map_err(|_| HeaderDatabaseError::WriteError)?;
.map_err(|_| DatabaseError::WriteError)?;
for (height, header) in header_chain {
if height.ge(&(best_height.unwrap_or(0))) {
let hash: String = header.block_hash().to_string();
Expand All @@ -140,22 +135,22 @@ impl HeaderStore for SqliteHeaderDb {
nonce
],
)
.map_err(|_| HeaderDatabaseError::WriteError)?;
.map_err(|_| DatabaseError::WriteError)?;
}
}
tx.commit().map_err(|_| HeaderDatabaseError::WriteError)?;
tx.commit().map_err(|_| DatabaseError::WriteError)?;
Ok(())
}

async fn write_over<'a>(
&mut self,
header_chain: &'a BTreeMap<u32, Header>,
height: u32,
) -> Result<(), HeaderDatabaseError> {
) -> Result<(), DatabaseError> {
let mut write_lock = self.conn.lock().await;
let tx = write_lock
.transaction()
.map_err(|_| HeaderDatabaseError::WriteError)?;
.map_err(|_| DatabaseError::WriteError)?;
for (h, header) in header_chain {
if h.ge(&height) {
let hash: String = header.block_hash().to_string();
Expand All @@ -179,22 +174,22 @@ impl HeaderStore for SqliteHeaderDb {
nonce
],
)
.map_err(|_| HeaderDatabaseError::WriteError)?;
.map_err(|_| DatabaseError::WriteError)?;
}
}
tx.commit().map_err(|_| HeaderDatabaseError::WriteError)?;
tx.commit().map_err(|_| DatabaseError::WriteError)?;
Ok(())
}

async fn height_of<'a>(
&mut self,
block_hash: &'a BlockHash,
) -> Result<Option<u32>, HeaderDatabaseError> {
) -> Result<Option<u32>, DatabaseError> {
let write_lock = self.conn.lock().await;
let stmt = "SELECT height FROM headers WHERE block_hash = ?1";
let row: Option<u32> = write_lock
.query_row(stmt, params![block_hash.to_string()], |row| row.get(0))
.map_err(|_| HeaderDatabaseError::LoadError)?;
.map_err(|_| DatabaseError::LoadError)?;
Ok(row)
}
}
Loading
Loading