-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #627 from blockscout/lok52/launcher-test-utils
Add launcher test utils
- Loading branch information
Showing
6 changed files
with
253 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
use crate::database::{ | ||
ConnectionTrait, Database, DatabaseConnection, DbErr, MigratorTrait, Statement, | ||
}; | ||
use std::{ops::Deref, sync::Arc}; | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct TestDbGuard { | ||
conn_with_db: Arc<DatabaseConnection>, | ||
conn_without_db: Arc<DatabaseConnection>, | ||
base_db_url: String, | ||
db_name: String, | ||
} | ||
|
||
impl TestDbGuard { | ||
pub async fn new<Migrator: MigratorTrait>(db_name: &str) -> Self { | ||
let base_db_url = std::env::var("DATABASE_URL") | ||
.expect("Database url must be set to initialize a test database"); | ||
let conn_without_db = Database::connect(&base_db_url) | ||
.await | ||
.expect("Connection to postgres (without database) failed"); | ||
// We use a hash, as the name itself may be quite long and be trimmed. | ||
let db_name = format!("_{:x}", keccak_hash::keccak(db_name)); | ||
let mut guard = TestDbGuard { | ||
conn_with_db: Arc::new(DatabaseConnection::Disconnected), | ||
conn_without_db: Arc::new(conn_without_db), | ||
base_db_url, | ||
db_name, | ||
}; | ||
|
||
guard.init_database().await; | ||
guard.run_migrations::<Migrator>().await; | ||
guard | ||
} | ||
|
||
pub fn client(&self) -> Arc<DatabaseConnection> { | ||
self.conn_with_db.clone() | ||
} | ||
|
||
pub fn db_url(&self) -> String { | ||
format!("{}/{}", self.base_db_url, self.db_name) | ||
} | ||
|
||
async fn init_database(&mut self) { | ||
// Create database | ||
self.drop_database().await; | ||
self.create_database().await; | ||
|
||
let db_url = self.db_url(); | ||
let conn_with_db = Database::connect(&db_url) | ||
.await | ||
.expect("Connection to postgres (with database) failed"); | ||
self.conn_with_db = Arc::new(conn_with_db); | ||
} | ||
|
||
pub async fn drop_database(&self) { | ||
Self::drop_database_internal(&self.conn_without_db, &self.db_name) | ||
.await | ||
.expect("Database drop failed"); | ||
} | ||
|
||
async fn create_database(&self) { | ||
Self::create_database_internal(&self.conn_without_db, &self.db_name) | ||
.await | ||
.expect("Database creation failed"); | ||
} | ||
|
||
async fn create_database_internal(db: &DatabaseConnection, db_name: &str) -> Result<(), DbErr> { | ||
tracing::info!(name = db_name, "creating database"); | ||
db.execute(Statement::from_string( | ||
db.get_database_backend(), | ||
format!("CREATE DATABASE {db_name}"), | ||
)) | ||
.await?; | ||
Ok(()) | ||
} | ||
|
||
async fn drop_database_internal(db: &DatabaseConnection, db_name: &str) -> Result<(), DbErr> { | ||
tracing::info!(name = db_name, "dropping database"); | ||
db.execute(Statement::from_string( | ||
db.get_database_backend(), | ||
format!("DROP DATABASE IF EXISTS {db_name} WITH (FORCE)"), | ||
)) | ||
.await?; | ||
Ok(()) | ||
} | ||
|
||
async fn run_migrations<Migrator: MigratorTrait>(&self) { | ||
Migrator::up(self.conn_with_db.as_ref(), None) | ||
.await | ||
.expect("Database migration failed"); | ||
} | ||
} | ||
|
||
impl Deref for TestDbGuard { | ||
type Target = DatabaseConnection; | ||
fn deref(&self) -> &Self::Target { | ||
&self.conn_with_db | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
use crate::launcher::ServerSettings; | ||
use reqwest::Url; | ||
use std::{ | ||
future::Future, | ||
net::{SocketAddr, TcpListener}, | ||
str::FromStr, | ||
time::Duration, | ||
}; | ||
use tokio::time::timeout; | ||
|
||
fn get_free_port() -> u16 { | ||
let listener = TcpListener::bind("127.0.0.1:0").unwrap(); | ||
listener.local_addr().unwrap().port() | ||
} | ||
|
||
pub fn get_test_server_settings() -> (ServerSettings, Url) { | ||
let mut server = ServerSettings::default(); | ||
let port = get_free_port(); | ||
server.http.addr = SocketAddr::from_str(&format!("127.0.0.1:{port}")).unwrap(); | ||
server.grpc.enabled = false; | ||
let base = Url::parse(&format!("http://{}", server.http.addr)).unwrap(); | ||
(server, base) | ||
} | ||
|
||
pub async fn init_server<F, R>(run: F, base: &Url) | ||
where | ||
F: FnOnce() -> R + Send + 'static, | ||
R: Future<Output = Result<(), anyhow::Error>> + Send, | ||
{ | ||
tokio::spawn(async move { run().await }); | ||
|
||
let client = reqwest::Client::new(); | ||
let health_endpoint = base.join("health").unwrap(); | ||
|
||
let wait_health_check = async { | ||
loop { | ||
if let Ok(_response) = client | ||
.get(health_endpoint.clone()) | ||
.query(&[("service", "")]) | ||
.send() | ||
.await | ||
{ | ||
break; | ||
} | ||
} | ||
}; | ||
// Wait for the server to start | ||
if (timeout(Duration::from_secs(10), wait_health_check).await).is_err() { | ||
panic!("Server did not start in time"); | ||
} | ||
} | ||
|
||
async fn send_annotated_request<Response: for<'a> serde::Deserialize<'a>>( | ||
url: &Url, | ||
route: &str, | ||
method: reqwest::Method, | ||
payload: Option<&impl serde::Serialize>, | ||
annotation: Option<&str>, | ||
) -> Response { | ||
let annotation = annotation.map(|v| format!("({v}) ")).unwrap_or_default(); | ||
|
||
let mut request = reqwest::Client::new().request(method, url.join(route).unwrap()); | ||
if let Some(p) = payload { | ||
request = request.json(p); | ||
}; | ||
let response = request | ||
.send() | ||
.await | ||
.unwrap_or_else(|_| panic!("{annotation}Failed to send request")); | ||
|
||
// Assert that status code is success | ||
if !response.status().is_success() { | ||
let status = response.status(); | ||
let message = response.text().await.expect("Read body as text"); | ||
panic!("({annotation})Invalid status code (success expected). Status: {status}. Message: {message}") | ||
} | ||
|
||
response | ||
.json() | ||
.await | ||
.unwrap_or_else(|_| panic!("({annotation})Response deserialization failed")) | ||
} | ||
|
||
pub async fn send_annotated_post_request<Response: for<'a> serde::Deserialize<'a>>( | ||
url: &Url, | ||
route: &str, | ||
payload: &impl serde::Serialize, | ||
annotation: &str, | ||
) -> Response { | ||
send_annotated_request( | ||
url, | ||
route, | ||
reqwest::Method::POST, | ||
Some(payload), | ||
Some(annotation), | ||
) | ||
.await | ||
} | ||
|
||
pub async fn send_post_request<Response: for<'a> serde::Deserialize<'a>>( | ||
url: &Url, | ||
route: &str, | ||
payload: &impl serde::Serialize, | ||
) -> Response { | ||
send_annotated_request(url, route, reqwest::Method::POST, Some(payload), None).await | ||
} | ||
|
||
pub async fn send_annotated_get_request<Response: for<'a> serde::Deserialize<'a>>( | ||
url: &Url, | ||
route: &str, | ||
annotation: &str, | ||
) -> Response { | ||
send_annotated_request( | ||
url, | ||
route, | ||
reqwest::Method::GET, | ||
None::<&()>, | ||
Some(annotation), | ||
) | ||
.await | ||
} | ||
|
||
pub async fn send_get_request<Response: for<'a> serde::Deserialize<'a>>( | ||
url: &Url, | ||
route: &str, | ||
) -> Response { | ||
send_annotated_request(url, route, reqwest::Method::GET, None::<&()>, None).await | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters