Skip to content

Commit

Permalink
Merge pull request #1847 from tursodatabase/lucio/1838
Browse files Browse the repository at this point in the history
libsql: restart sync on lower frame_no from remote
  • Loading branch information
penberg authored Nov 30, 2024
2 parents 15849e4 + 6158cc2 commit 9241b00
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 12 deletions.
23 changes: 23 additions & 0 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,29 @@ impl Database {
#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
use crate::sync::SyncError;
use crate::Error;

match self.try_push().await {
Ok(rep) => Ok(rep),
Err(Error::Sync(err)) => {
// Retry the sync because we are ahead of the server and we need to push some older
// frames.
if let Some(SyncError::InvalidPushFrameNoLow(_, _)) =
err.downcast_ref::<SyncError>()
{
tracing::debug!("got InvalidPushFrameNo, retrying push");
self.try_push().await
} else {
Err(Error::Sync(err))
}
}
Err(e) => Err(e),
}
}

#[cfg(feature = "sync")]
async fn try_push(&self) -> Result<crate::database::Replicated> {
let mut sync_ctx = self.sync_ctx.as_ref().unwrap().lock().await;
let conn = self.connect()?;

Expand Down
41 changes: 39 additions & 2 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub enum SyncError {
VerifyVersion(u32, u32),
#[error("failed to verify metadata file hash: expected={0}, got={1}")]
VerifyHash(u32, u32),
#[error("server returned a lower frame_no: sent={0}, got={1}")]
InvalidPushFrameNoLow(u32, u32),
#[error("server returned a higher frame_no: sent={0}, got={1}")]
InvalidPushFrameNoHigh(u32, u32),
}

impl SyncError {
Expand Down Expand Up @@ -91,7 +95,10 @@ impl SyncContext {
};

if let Err(e) = me.read_metadata().await {
tracing::error!("failed to read sync metadata file: {}", e);
tracing::error!(
"failed to read sync metadata file, resetting back to defaults: {}",
e
);
}

Ok(me)
Expand All @@ -115,6 +122,30 @@ impl SyncContext {

let durable_frame_num = self.push_with_retry(uri, frame, self.max_retries).await?;

if durable_frame_num > frame_no {
tracing::error!(
"server returned durable_frame_num larger than what we sent: sent={}, got={}",
frame_no,
durable_frame_num
);

return Err(SyncError::InvalidPushFrameNoHigh(frame_no, durable_frame_num).into());
}

if durable_frame_num < frame_no {
// Update our knowledge of where the server is at frame wise.
self.durable_frame_num = durable_frame_num;

tracing::debug!(
"server returned durable_frame_num lower than what we sent: sent={}, got={}",
frame_no,
durable_frame_num
);

// Return an error and expect the caller to re-call push with the updated state.
return Err(SyncError::InvalidPushFrameNoLow(frame_no, durable_frame_num).into());
}

tracing::debug!(?durable_frame_num, "frame successfully pushed");

// Update our last known max_frame_no from the server.
Expand Down Expand Up @@ -232,14 +263,20 @@ impl SyncContext {
return Err(SyncError::VerifyVersion(metadata.version, METADATA_VERSION).into());
}

tracing::debug!(
"read sync metadata for db_path={:?}, metadata={:?}",
self.db_path,
metadata
);

self.durable_frame_num = metadata.durable_frame_num;
self.generation = metadata.generation;

Ok(())
}
}

#[derive(serde::Serialize, serde::Deserialize)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct MetadataJson {
hash: u32,
version: u32,
Expand Down
79 changes: 69 additions & 10 deletions libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tempfile::tempdir;
use tokio::io::{duplex, AsyncRead, AsyncWrite, DuplexStream};
use tower::Service;
use std::time::Duration;

#[tokio::test]
async fn test_sync_context_push_frame() {
Expand All @@ -30,10 +30,10 @@ async fn test_sync_context_push_frame() {
// Push a frame and verify the response
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1); // First frame should return max_frame_no = 1
assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0

// Verify internal state was updated
assert_eq!(sync_ctx.durable_frame_num(), 1);
assert_eq!(sync_ctx.durable_frame_num(), 0);
assert_eq!(sync_ctx.generation(), 1);
assert_eq!(server.frame_count(), 1);
}
Expand All @@ -58,7 +58,7 @@ async fn test_sync_context_with_auth() {

let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
}

Expand All @@ -84,8 +84,8 @@ async fn test_sync_context_multiple_frames() {
let frame = Bytes::from(format!("frame data {}", i));
let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, i + 1);
assert_eq!(sync_ctx.durable_frame_num(), i + 1);
assert_eq!(durable_frame, i);
assert_eq!(sync_ctx.durable_frame_num(), i);
assert_eq!(server.frame_count(), i + 1);
}
}
Expand All @@ -110,7 +110,7 @@ async fn test_sync_context_corrupted_metadata() {
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);

// Update metadata path to use -info instead of .meta
Expand All @@ -132,11 +132,69 @@ async fn test_sync_context_corrupted_metadata() {
assert_eq!(sync_ctx.generation(), 1);
}

#[tokio::test]
async fn test_sync_restarts_with_lower_max_frame_no() {
let _ = tracing_subscriber::fmt::try_init();

let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");

// Create initial sync context and push a frame
let sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
)
.await
.unwrap();

let mut sync_ctx = sync_ctx;
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame.clone(), 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);

// Bump the durable frame num so that the next time we call the
// server we think we are further ahead than the database we are talking to is.
sync_ctx.durable_frame_num += 3;
sync_ctx.write_metadata().await.unwrap();

// Create new sync context with corrupted metadata
let mut sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
)
.await
.unwrap();

// Verify that the context was set to new fake values.
assert_eq!(sync_ctx.durable_frame_num(), 3);
assert_eq!(sync_ctx.generation(), 1);

let frame_no = sync_ctx.durable_frame_num() + 1;
// This push should fail because we are ahead of the server and thus should get an invalid
// frame no error.
sync_ctx
.push_one_frame(frame.clone(), 1, frame_no)
.await
.unwrap_err();

let frame_no = sync_ctx.durable_frame_num() + 1;
// This then should work because when the last one failed it updated our state of the server
// durable_frame_num and we should then start writing from there.
sync_ctx.push_one_frame(frame, 1, frame_no).await.unwrap();
}

#[tokio::test]
async fn test_sync_context_retry_on_error() {
// Pause time to control it manually
tokio::time::pause();

let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");
Expand Down Expand Up @@ -172,7 +230,7 @@ async fn test_sync_context_retry_on_error() {
// Next attempt should succeed
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
}

Expand Down Expand Up @@ -316,8 +374,9 @@ impl MockServer {
let current_count = frame_count.fetch_add(1, Ordering::SeqCst);

if req.uri().path().contains("/sync/") {
// Return the max_frame_no that has been accepted
let response = serde_json::json!({
"max_frame_no": current_count + 1
"max_frame_no": current_count
});

Ok::<_, hyper::Error>(
Expand Down

0 comments on commit 9241b00

Please sign in to comment.