Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libsql: restart sync on lower frame_no from remote #1847

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,29 @@ impl Database {
#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
use crate::sync::SyncError;
use crate::Error;

match self.try_push().await {
Ok(rep) => Ok(rep),
Err(Error::Sync(err)) => {
// Retry the sync because we are ahead of the server and we need to push some older
// frames.
if let Some(SyncError::InvalidPushFrameNoLow(_, _)) =
err.downcast_ref::<SyncError>()
{
tracing::debug!("got InvalidPushFrameNo, retrying push");
self.try_push().await
} else {
Err(Error::Sync(err))
}
}
Err(e) => Err(e),
}
}

#[cfg(feature = "sync")]
async fn try_push(&self) -> Result<crate::database::Replicated> {
let mut sync_ctx = self.sync_ctx.as_ref().unwrap().lock().await;
let conn = self.connect()?;

Expand Down
41 changes: 39 additions & 2 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub enum SyncError {
VerifyVersion(u32, u32),
#[error("failed to verify metadata file hash: expected={0}, got={1}")]
VerifyHash(u32, u32),
#[error("server returned a lower frame_no: sent={0}, got={1}")]
InvalidPushFrameNoLow(u32, u32),
#[error("server returned a higher frame_no: sent={0}, got={1}")]
InvalidPushFrameNoHigh(u32, u32),
}

impl SyncError {
Expand Down Expand Up @@ -91,7 +95,10 @@ impl SyncContext {
};

if let Err(e) = me.read_metadata().await {
tracing::error!("failed to read sync metadata file: {}", e);
tracing::error!(
"failed to read sync metadata file, resetting back to defaults: {}",
e
);
}

Ok(me)
Expand All @@ -115,6 +122,30 @@ impl SyncContext {

let durable_frame_num = self.push_with_retry(uri, frame, self.max_retries).await?;

if durable_frame_num > frame_no {
tracing::error!(
"server returned durable_frame_num larger than what we sent: sent={}, got={}",
frame_no,
durable_frame_num
);

return Err(SyncError::InvalidPushFrameNoHigh(frame_no, durable_frame_num).into());
}

if durable_frame_num < frame_no {
// Update our knowledge of where the server is at frame wise.
self.durable_frame_num = durable_frame_num;

tracing::debug!(
"server returned durable_frame_num lower than what we sent: sent={}, got={}",
frame_no,
durable_frame_num
);

// Return an error and expect the caller to re-call push with the updated state.
return Err(SyncError::InvalidPushFrameNoLow(frame_no, durable_frame_num).into());
}

tracing::debug!(?durable_frame_num, "frame successfully pushed");

// Update our last known max_frame_no from the server.
Expand Down Expand Up @@ -232,14 +263,20 @@ impl SyncContext {
return Err(SyncError::VerifyVersion(metadata.version, METADATA_VERSION).into());
}

tracing::debug!(
"read sync metadata for db_path={:?}, metadata={:?}",
self.db_path,
metadata
);

self.durable_frame_num = metadata.durable_frame_num;
self.generation = metadata.generation;

Ok(())
}
}

#[derive(serde::Serialize, serde::Deserialize)]
#[derive(serde::Serialize, serde::Deserialize, Debug)]
struct MetadataJson {
hash: u32,
version: u32,
Expand Down
79 changes: 69 additions & 10 deletions libsql/src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tempfile::tempdir;
use tokio::io::{duplex, AsyncRead, AsyncWrite, DuplexStream};
use tower::Service;
use std::time::Duration;

#[tokio::test]
async fn test_sync_context_push_frame() {
Expand All @@ -30,10 +30,10 @@ async fn test_sync_context_push_frame() {
// Push a frame and verify the response
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1); // First frame should return max_frame_no = 1
assert_eq!(durable_frame, 0); // First frame should return max_frame_no = 0

// Verify internal state was updated
assert_eq!(sync_ctx.durable_frame_num(), 1);
assert_eq!(sync_ctx.durable_frame_num(), 0);
assert_eq!(sync_ctx.generation(), 1);
assert_eq!(server.frame_count(), 1);
}
Expand All @@ -58,7 +58,7 @@ async fn test_sync_context_with_auth() {

let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
}

Expand All @@ -84,8 +84,8 @@ async fn test_sync_context_multiple_frames() {
let frame = Bytes::from(format!("frame data {}", i));
let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, i + 1);
assert_eq!(sync_ctx.durable_frame_num(), i + 1);
assert_eq!(durable_frame, i);
assert_eq!(sync_ctx.durable_frame_num(), i);
assert_eq!(server.frame_count(), i + 1);
}
}
Expand All @@ -110,7 +110,7 @@ async fn test_sync_context_corrupted_metadata() {
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);

// Update metadata path to use -info instead of .meta
Expand All @@ -132,11 +132,69 @@ async fn test_sync_context_corrupted_metadata() {
assert_eq!(sync_ctx.generation(), 1);
}

#[tokio::test]
async fn test_sync_restarts_with_lower_max_frame_no() {
let _ = tracing_subscriber::fmt::try_init();

let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");

// Create initial sync context and push a frame
let sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
)
.await
.unwrap();

let mut sync_ctx = sync_ctx;
let frame = Bytes::from("test frame data");
let durable_frame = sync_ctx.push_one_frame(frame.clone(), 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);

// Bump the durable frame num so that the next time we call the
// server we think we are further ahead than the database we are talking to is.
sync_ctx.durable_frame_num += 3;
sync_ctx.write_metadata().await.unwrap();

// Create new sync context with corrupted metadata
let mut sync_ctx = SyncContext::new(
server.connector(),
db_path.to_str().unwrap().to_string(),
server.url(),
None,
)
.await
.unwrap();

// Verify that the context was set to new fake values.
assert_eq!(sync_ctx.durable_frame_num(), 3);
assert_eq!(sync_ctx.generation(), 1);

let frame_no = sync_ctx.durable_frame_num() + 1;
// This push should fail because we are ahead of the server and thus should get an invalid
// frame no error.
sync_ctx
.push_one_frame(frame.clone(), 1, frame_no)
.await
.unwrap_err();

let frame_no = sync_ctx.durable_frame_num() + 1;
// This then should work because when the last one failed it updated our state of the server
// durable_frame_num and we should then start writing from there.
sync_ctx.push_one_frame(frame, 1, frame_no).await.unwrap();
}

#[tokio::test]
async fn test_sync_context_retry_on_error() {
// Pause time to control it manually
tokio::time::pause();

let server = MockServer::start();
let temp_dir = tempdir().unwrap();
let db_path = temp_dir.path().join("test.db");
Expand Down Expand Up @@ -172,7 +230,7 @@ async fn test_sync_context_retry_on_error() {
// Next attempt should succeed
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
sync_ctx.write_metadata().await.unwrap();
assert_eq!(durable_frame, 1);
assert_eq!(durable_frame, 0);
assert_eq!(server.frame_count(), 1);
}

Expand Down Expand Up @@ -316,8 +374,9 @@ impl MockServer {
let current_count = frame_count.fetch_add(1, Ordering::SeqCst);

if req.uri().path().contains("/sync/") {
// Return the max_frame_no that has been accepted
let response = serde_json::json!({
"max_frame_no": current_count + 1
"max_frame_no": current_count
});

Ok::<_, hyper::Error>(
Expand Down
Loading