Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libsql: rework sync v2 structure #1820

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions libsql/src/database/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,19 @@ cfg_sync! {

let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();

// TODO: add config to set custom connector
let https = super::connector()?;

use tower::ServiceExt;

let svc = https
.map_err(|e| e.into())
.map_response(|s| Box::new(s) as Box<dyn crate::util::Socket>);

let connector = crate::util::ConnectorService::new(svc);

let db = crate::local::Database::open_local_with_offline_writes(
connector,
path,
flags,
url,
Expand Down
34 changes: 34 additions & 0 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};

use crate::TransactionBehavior;

use bytes::{BufMut, BytesMut};
use libsql_sys::ffi;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};

Expand Down Expand Up @@ -445,6 +446,39 @@ impl Connection {
}
}
}

pub fn wal_frame_count(&self) -> u32 {
let mut max_frame_no: std::os::raw::c_uint = 0;
unsafe { libsql_sys::ffi::libsql_wal_frame_count(self.handle(), &mut max_frame_no) };

max_frame_no
}

pub fn wal_get_frame(&self, frame_no: u32, page_size: u32) -> Result<BytesMut> {
let frame_size: usize = 24 + page_size as usize;

let mut buf = BytesMut::with_capacity(frame_size);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this better than just vec![0; frame_size]?


let rc = unsafe {
libsql_sys::ffi::libsql_wal_get_frame(
self.handle(),
frame_no,
buf.chunk_mut().as_mut_ptr() as *mut _,
frame_size as u32,
)
};

if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
format!("Failed to get frame: {}", frame_no),
));
}

unsafe { buf.advance_mut(frame_size) };

Ok(buf)
}
}

impl fmt::Debug for Connection {
Expand Down
83 changes: 30 additions & 53 deletions libsql/src/local/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct Database {
#[cfg(feature = "replication")]
pub replication_ctx: Option<ReplicationContext>,
#[cfg(feature = "sync")]
pub sync_ctx: Option<SyncContext>,
pub sync_ctx: Option<tokio::sync::Mutex<SyncContext>>,
}

impl Database {
Expand Down Expand Up @@ -131,6 +131,7 @@ impl Database {
#[cfg(feature = "sync")]
#[doc(hidden)]
pub async fn open_local_with_offline_writes(
connector: crate::util::ConnectorService,
db_path: impl Into<String>,
flags: OpenFlags,
endpoint: String,
Expand All @@ -143,7 +144,10 @@ impl Database {
endpoint
};
let mut db = Database::open(&db_path, flags)?;
db.sync_ctx = Some(SyncContext::new(endpoint, Some(auth_token)));

let ctx = SyncContext::new(endpoint, Some(auth_token), &db_path, connector).await;

db.sync_ctx = Some(tokio::sync::Mutex::new(ctx));
Ok(db)
}

Expand Down Expand Up @@ -320,7 +324,10 @@ impl Database {

#[cfg(feature = "replication")]
/// Sync with primary at least to a given replication index
pub async fn sync_until(&self, replication_index: FrameNo) -> Result<crate::database::Replicated> {
pub async fn sync_until(
&self,
replication_index: FrameNo,
) -> Result<crate::database::Replicated> {
Comment on lines +327 to +330
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid reformattings like that in PRs that are supposed to do something non-trivial.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really odd, my editor is auto formatting this but with cargo fmt which we use to check on CI so we avoid these changes its not formatting lines with #[cfg(feature = "")]. Gonna take a look into this, seems really odd that cargo fmt isn't picking this up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't invest time into that. Just when you see stuff like this put it on a separate PR and even merge it without review.

if let Some(ctx) = &self.replication_ctx {
let mut frame_no: Option<FrameNo> = ctx.replicator.committed_frame_no().await;
let mut frames_synced: usize = 0;
Expand Down Expand Up @@ -380,84 +387,54 @@ impl Database {
#[cfg(feature = "sync")]
/// Push WAL frames to remote.
pub async fn push(&self) -> Result<crate::database::Replicated> {
let sync_ctx = self.sync_ctx.as_ref().unwrap();
let mut ctx = match &self.sync_ctx {
Some(ctx) => ctx.lock().await,
None => panic!("sync context not set"),
};

let conn = self.connect()?;

// TODO: can this be cached?
let page_size = {
let rows = conn.query("PRAGMA page_size", crate::params::Params::None)?.unwrap();
let rows = conn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again - please avoid noise like that in PRs

.query("PRAGMA page_size", crate::params::Params::None)?
.unwrap();
let row = rows.next()?.unwrap();
let page_size = row.get::<u32>(0)?;
page_size
};

let mut max_frame_no: std::os::raw::c_uint = 0;
unsafe { libsql_sys::ffi::libsql_wal_frame_count(conn.handle(), &mut max_frame_no) };

let max_frame_no = conn.wal_frame_count();

let generation = 1; // TODO: Probe from WAL.
let start_frame_no = sync_ctx.durable_frame_num + 1;
let start_frame_no = ctx.durable_frame_num() + 1;
let end_frame_no = max_frame_no;

// TODO: figure out relation to durable_frame_num
// let max_frame_no = ctx.max_frame_no();
Comment on lines +413 to +414
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@penberg before we can merge this would be helpful to clairify the difference between durable_frame_ num and max_frame_no. What I understand is we want to write the durable_frame_num here (which I guess is the max_frame_no from the server) to disk to ensure we start at a more advanced wal frame than 0.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You always have two different max frame numbers: one on client, one on server. I tend to call the latter durable_frame_num.


let mut frame_no = start_frame_no;
while frame_no <= end_frame_no {
// The server returns its maximum frame number. To avoid resending
// frames the server already knows about, we need to update the
// frame number to the one returned by the server.
let max_frame_no = self.push_one_frame(&conn, &sync_ctx, generation, frame_no, page_size).await?;
let frame = conn.wal_get_frame(frame_no, page_size)?;

let max_frame_no = ctx.send_frame(frame.freeze(), generation, frame_no).await?;

if max_frame_no > frame_no {
frame_no = max_frame_no;
}
frame_no += 1;
}

let frame_count = end_frame_no - start_frame_no + 1;
Ok(crate::database::Replicated{
Ok(crate::database::Replicated {
frame_no: None,
frames_synced: frame_count as usize,
})
}

#[cfg(feature = "sync")]
async fn push_one_frame(&self, conn: &Connection, sync_ctx: &SyncContext, generation: u32, frame_no: u32, page_size: u32) -> Result<u32> {
let frame_size: usize = 24+page_size as usize;
let frame = vec![0; frame_size];
let rc = unsafe {
libsql_sys::ffi::libsql_wal_get_frame(conn.handle(), frame_no, frame.as_ptr() as *mut _, frame_size as u32)
};
if rc != 0 {
return Err(crate::errors::Error::SqliteFailure(rc as std::ffi::c_int, format!("Failed to get frame: {}", frame_no)));
}
let uri = format!("{}/sync/{}/{}/{}", sync_ctx.sync_url, generation, frame_no, frame_no+1);
let max_frame_no = self.push_with_retry(uri, &sync_ctx.auth_token, frame.to_vec(), sync_ctx.max_retries).await?;
Ok(max_frame_no)
}

#[cfg(feature = "sync")]
async fn push_with_retry(&self, uri: String, auth_token: &Option<String>, frame: Vec<u8>, max_retries: usize) -> Result<u32> {
let mut nr_retries = 0;
loop {
let client = reqwest::Client::new();
let mut builder = client.post(uri.to_owned());
match auth_token {
Some(ref auth_token) => {
builder = builder.header("Authorization", format!("Bearer {}", auth_token.to_owned()));
}
None => {}
}
let res = builder.body(frame.to_vec()).send().await.unwrap();
if res.status().is_success() {
let resp = res.json::<serde_json::Value>().await.unwrap();
let max_frame_no = resp.get("max_frame_no").unwrap().as_u64().unwrap();
return Ok(max_frame_no as u32);
}
if nr_retries > max_retries {
return Err(crate::errors::Error::ConnectionFailed(format!("Failed to push frame: {}", res.status())));
}
let delay = std::time::Duration::from_millis(100 * (1 << nr_retries));
tokio::time::sleep(delay).await;
nr_retries += 1;
}
}

pub(crate) fn path(&self) -> &str {
&self.db_path
}
Expand Down
143 changes: 137 additions & 6 deletions libsql/src/sync.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,149 @@
const DEFAULT_MAX_RETRIES: usize = 5;

use bytes::Bytes;
use http::{HeaderValue, Request, Uri};
use tokio::sync::Mutex;

use crate::{util::ConnectorService, Result};

pub struct SyncContext {
pub sync_url: String,
pub auth_token: Option<String>,
pub max_retries: usize,
pub durable_frame_num: u32,
sync_url: String,
auth_token: Option<String>,
max_retries: usize,
durable_frame_num: u32,
db_path: String,
max_frame_no: u32,

client: hyper::Client<ConnectorService, hyper::Body>,
}

impl SyncContext {
pub fn new(sync_url: String, auth_token: Option<String>) -> Self {
Self {
pub async fn new(
sync_url: String,
auth_token: Option<String>,
db_path: impl Into<String>,
connector: ConnectorService,
) -> Self {
let mut ctx = Self {
sync_url,
auth_token,
durable_frame_num: 0,
max_retries: DEFAULT_MAX_RETRIES,
db_path: db_path.into(),
max_frame_no: 0,
client: hyper::Client::builder().build(connector),
};

ctx.read_and_update_metadata().await.unwrap();

ctx
}

pub(crate) async fn send_frame(
&mut self,
frame: Bytes,
generation: u32,
frame_no: u32,
) -> Result<u32> {
let url = format!(
"{}/sync/{}/{}/{}",
self.sync_url,
generation,
frame_no,
frame_no + 1
);

let maybe_auth_header = if let Some(auth_token) = &self.auth_token {
Some(HeaderValue::from_str(&format!("Bearer {}", auth_token)).unwrap())
} else {
None
};

let mut attempts = 0;

loop {
let mut req = Request::post(url.clone());

if let Some(auth_header) = &maybe_auth_header {
req.headers_mut()
.unwrap()
.insert("Authorization", auth_header.clone());
}

let req = req.body(frame.clone().into()).unwrap();

let res = self.client.request(req).await.unwrap();

if res.status().is_success() {
let body = hyper::body::to_bytes(res.into_body()).await.unwrap();

let resp = serde_json::from_slice::<serde_json::Value>(&body[..]).unwrap();

let max_frame_no = resp.get("max_frame_no").unwrap().as_u64().unwrap() as u32;

// Update our best known max_frame_no from the server and write it to disk.
self.set_max_frame_no(max_frame_no).await.unwrap();

return Ok(max_frame_no);
} else if res.status().is_server_error() || attempts < self.max_retries {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@penberg small change compared to what you had, I just retry on server failures rather than all failures like auth_token being invalid.

let delay = std::time::Duration::from_millis(100 * (1 << attempts));
tokio::time::sleep(delay).await;
attempts += 1;

continue;
} else {
return Err(crate::errors::Error::ConnectionFailed(format!(
"Failed to push frame: {}",
res.status()
)));
}
}
}

pub(crate) fn max_frame_no(&self) -> u32 {
self.max_frame_no
}

pub(crate) fn durable_frame_num(&self) -> u32 {
self.durable_frame_num
}

pub(crate) async fn set_max_frame_no(&mut self, max_frame_no: u32) -> Result<()> {
// TODO: check if max_frame_no is larger than current known max_frame_no
self.max_frame_no = max_frame_no;

self.update_metadata().await?;

Ok(())
}

async fn update_metadata(&mut self) -> Result<()> {
let path = format!("{}-info", self.db_path);

let contents = serde_json::to_vec(&MetadataJson {
max_frame_no: self.max_frame_no,
})
.unwrap();

tokio::fs::write(path, contents).await.unwrap();

Ok(())
}

async fn read_and_update_metadata(&mut self) -> Result<()> {
let path = format!("{}-info", self.db_path);

let contents = tokio::fs::read(&path).await.unwrap();

let metadata = serde_json::from_slice::<MetadataJson>(&contents[..]).unwrap();

self.max_frame_no = metadata.max_frame_no;

Ok(())
}
Comment on lines +111 to +143
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's all of this? Can we have one PR that just switches to hyper and all other changes in another PR please?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But also what's this?

}

#[derive(serde::Serialize, serde::Deserialize)]
struct MetadataJson {
max_frame_no: u32,
}
Loading