Skip to content

Commit

Permalink
libsql: offline sync retry on server errors only
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco committed Nov 21, 2024
1 parent 26ac07e commit 11cce47
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
5 changes: 3 additions & 2 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ impl SyncContext {
.await
.map_err(SyncError::HttpDispatch)?;

// TODO(lucio): only retry on server side errors
if res.status().is_success() {
let res_body = hyper::body::to_bytes(res.into_body())
.await
Expand All @@ -165,7 +164,9 @@ impl SyncContext {
return Ok(max_frame_no as u32);
}

if nr_retries > max_retries {
// If we've retried too many times or the error is not a server error,
// return the error.
if nr_retries > max_retries || !res.status().is_server_error() {
let status = res.status();

let res_body = hyper::body::to_bytes(res.into_body())
Expand Down
75 changes: 74 additions & 1 deletion libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::*;
use crate::util::Socket;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tempfile::tempdir;
use tokio::io::{duplex, AsyncRead, AsyncWrite, DuplexStream};
use tower::Service;
use std::time::Duration;

#[tokio::test]
async fn test_sync_context_push_frame() {
Expand Down Expand Up @@ -131,6 +132,50 @@ async fn test_sync_context_corrupted_metadata() {
assert_eq!(sync_ctx.generation(), 1);
}

#[tokio::test]
async fn test_sync_context_retry_on_error() {
// Pause time to control it manually
tokio::time::pause();

let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");

let sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
)
.await
.unwrap();

let mut sync_ctx = sync_ctx;
let frame = Bytes::from("test frame data");

// Set server to return errors
server.return_error.store(true, Ordering::SeqCst);

// First attempt should fail but retry
let result = sync_ctx.push_one_frame(frame.clone(), 1, 0).await;
assert!(result.is_err());

// Advance time to trigger retries faster
tokio::time::advance(Duration::from_secs(2)).await;

// Verify multiple requests were made (retries occurred)
assert!(server.request_count() > 1);

// Allow the server to succeed
server.return_error.store(false, Ordering::SeqCst);

// Next attempt should succeed
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(server.frame_count(), 1);
}

#[test]
fn test_hash_verification() {
let mut metadata = MetadataJson {
Expand Down Expand Up @@ -212,11 +257,15 @@ struct MockServer {
url: String,
frame_count: Arc<AtomicU32>,
connector: ConnectorService,
return_error: Arc<AtomicBool>,
request_count: Arc<AtomicU32>,
}

impl MockServer {
fn start() -> Self {
let frame_count = Arc::new(AtomicU32::new(0));
let return_error = Arc::new(AtomicBool::new(false));
let request_count = Arc::new(AtomicU32::new(0));

// Create the mock connector with Some(client_stream)
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
Expand All @@ -227,23 +276,43 @@ impl MockServer {
url: "http://mock.server".to_string(),
frame_count: frame_count.clone(),
connector,
return_error: return_error.clone(),
request_count: request_count.clone(),
};

// Spawn the server handler
let frame_count_clone = frame_count.clone();
let return_error_clone = return_error.clone();
let request_count_clone = request_count.clone();

tokio::spawn(async move {
while let Some(server_stream) = rx.recv().await {
let frame_count_clone = frame_count_clone.clone();
let return_error_clone = return_error_clone.clone();
let request_count_clone = request_count_clone.clone();

tokio::spawn(async move {
use hyper::server::conn::Http;
use hyper::service::service_fn;

let frame_count_clone = frame_count_clone.clone();
let return_error_clone = return_error_clone.clone();
let request_count_clone = request_count_clone.clone();
let service = service_fn(move |req: http::Request<Body>| {
let frame_count = frame_count_clone.clone();
let return_error = return_error_clone.clone();
let request_count = request_count_clone.clone();
async move {
request_count.fetch_add(1, Ordering::SeqCst);
if return_error.load(Ordering::SeqCst) {
return Ok::<_, hyper::Error>(
http::Response::builder()
.status(500)
.body(Body::from("Internal Server Error"))
.unwrap(),
);
}

let current_count = frame_count.fetch_add(1, Ordering::SeqCst);

if req.uri().path().contains("/sync/") {
Expand Down Expand Up @@ -287,6 +356,10 @@ impl MockServer {
fn frame_count(&self) -> u32 {
self.frame_count.load(Ordering::SeqCst)
}

fn request_count(&self) -> u32 {
self.request_count.load(Ordering::SeqCst)
}
}

// Mock connection that implements the Socket trait
Expand Down

0 comments on commit 11cce47

Please sign in to comment.