Skip to content

Commit

Permalink
feat: add callback to client proving server trait
Browse files Browse the repository at this point in the history
  • Loading branch information
rpalakkal committed May 3, 2024
1 parent 7bac127 commit 494cc83
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
1 change: 1 addition & 0 deletions sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
#![allow(incomplete_features)]
#![feature(associated_type_defaults)]
#![feature(async_fn_in_trait)]
mod api;

// pub(crate) mod utils;
Expand Down
63 changes: 52 additions & 11 deletions sdk/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use axiom_circuit::{
use ethers::providers::{Http, Provider};
use rocket::State;
use serde::de::DeserializeOwned;
use tokio::runtime::Runtime;

use self::types::{
AggregationCircuitCtx, AxiomComputeCircuitCtx, AxiomComputeCtx, AxiomComputeJobStatus,
Expand All @@ -31,13 +32,24 @@ use crate::{
pub mod types;

pub trait AxiomClientProvingServer: AxiomCircuitScaffold<Http, Fr> {
type Context = ();
type Context: Clone = ();
type ServerPayload = ();
type RequestInput: DeserializeOwned;

fn construct_context(params: Self::CoreParams) -> Self::Context;
fn process_input(ctx: Self::Context, input: Self::RequestInput) -> Self::InputValue;
async fn process_input(
ctx: Self::Context,
input: Self::RequestInput,
provider: Provider<Http>,
) -> Result<(Self::InputValue, Self::ServerPayload), String>;
#[allow(unused_variables)]
fn after_prove(ctx: Self::Context, output: AxiomV2CircuitOutput) {}
fn callback(
ctx: Self::Context,
payload: Self::ServerPayload,
output: AxiomV2CircuitOutput,
) -> Result<(), ()> {
Ok(())
}
}

impl<A: AxiomComputeFn> AxiomClientProvingServer for AxiomCompute<A>
Expand All @@ -50,8 +62,12 @@ where
fn construct_context(_: A::CoreParams) -> Self::Context {
()
}
fn process_input(_ctx: Self::Context, input: Self::RequestInput) -> A::Input<Fr> {
input.into()
async fn process_input(
_ctx: Self::Context,
input: Self::RequestInput,
_provider: Provider<Http>,
) -> Result<(A::Input<Fr>, ()), String> {
Ok((input.into(), ()))
}
}

Expand All @@ -73,7 +89,7 @@ pub async fn add_job(ctx: &State<AxiomComputeManager>, job: String) -> u64 {

pub fn prover_loop<A: AxiomClientProvingServer>(
manager: AxiomComputeManager,
ctx: AxiomComputeCtx<A::CoreParams>,
ctx: AxiomComputeCtx<A>,
mut shutdown: tokio::sync::mpsc::Receiver<()>,
) {
loop {
Expand All @@ -89,8 +105,19 @@ pub fn prover_loop<A: AxiomClientProvingServer>(
};
let raw_input = inputs.get(&job).unwrap();
let input: A::RequestInput = serde_json::from_str(raw_input).unwrap();
let server_ctx = A::construct_context(ctx.child.pinning.core_params.clone());
let processed_input = A::process_input(server_ctx, input);
let rt = Runtime::new().unwrap();
let processed_input = rt.block_on(async {
A::process_input(ctx.server_ctx.clone(), input, ctx.provider.clone()).await
});
if processed_input.is_err() {
manager
.job_status
.lock()
.unwrap()
.insert(job, AxiomComputeJobStatus::ErrorInputPrep);
continue;
}
let (processed_input, payload) = processed_input.unwrap();
let runner = AxiomCircuit::<Fr, Http, A>::prover(
ctx.provider.clone(),
ctx.child.pinning.clone(),
Expand Down Expand Up @@ -124,12 +151,23 @@ pub fn prover_loop<A: AxiomClientProvingServer>(
} else {
inner_output
};
manager.outputs.lock().unwrap().insert(job, output);
manager.outputs.lock().unwrap().insert(job, output.clone());
manager
.job_status
.lock()
.unwrap()
.insert(job, AxiomComputeJobStatus::OutputReady);
let callback = A::callback(ctx.server_ctx.clone(), payload, output);
let callback_status = if callback.is_ok() {
AxiomComputeJobStatus::CallbackSuccess
} else {
AxiomComputeJobStatus::ErrorCallback
};
manager
.job_status
.lock()
.unwrap()
.insert(job, callback_status);
} else {
if shutdown.try_recv().is_ok() {
break;
Expand All @@ -139,9 +177,9 @@ pub fn prover_loop<A: AxiomClientProvingServer>(
}
}

pub fn initialize<A: AxiomCircuitScaffold<Http, Fr>>(
pub fn initialize<A: AxiomClientProvingServer>(
options: AxiomComputeServerCmd,
) -> AxiomComputeCtx<A::CoreParams> {
) -> AxiomComputeCtx<A> {
let data_path = PathBuf::from(options.data_path);
let srs_path = PathBuf::from(options.srs_path);
let metadata =
Expand All @@ -168,6 +206,8 @@ pub fn initialize<A: AxiomCircuitScaffold<Http, Fr>>(
}
});

let server_ctx = A::construct_context(pinning.clone().core_params);

AxiomComputeCtx {
child: AxiomComputeCircuitCtx {
pk,
Expand All @@ -176,6 +216,7 @@ pub fn initialize<A: AxiomCircuitScaffold<Http, Fr>>(
},
agg,
provider,
server_ctx,
}
}
#[rocket::get("/job_status/<id>")]
Expand Down
14 changes: 10 additions & 4 deletions sdk/src/server/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use clap::Parser;
use ethers::providers::{Http, Provider};
use serde::Serialize;

use super::AxiomClientProvingServer;

#[derive(Clone, Debug)]
pub struct AxiomComputeCircuitCtx<CoreParams> {
pub pk: ProvingKey<G1Affine>,
Expand All @@ -31,10 +33,11 @@ pub struct AggregationCircuitCtx<CoreParams> {
}

#[derive(Clone, Debug)]
pub struct AxiomComputeCtx<CoreParams> {
pub child: AxiomComputeCircuitCtx<CoreParams>,
pub agg: Option<AggregationCircuitCtx<CoreParams>>,
pub struct AxiomComputeCtx<A: AxiomClientProvingServer> {
pub child: AxiomComputeCircuitCtx<A::CoreParams>,
pub agg: Option<AggregationCircuitCtx<A::CoreParams>>,
pub provider: Provider<Http>,
pub server_ctx: A::Context,
}

#[derive(Clone, Debug, Serialize)]
Expand All @@ -43,7 +46,10 @@ pub enum AxiomComputeJobStatus {
DataQueryReady,
InnerOutputReady,
OutputReady,
Error,
CallbackSuccess,
ErrorInputPrep,
ErrorCallback,
ErrorInnerCircuit,
}

#[derive(Clone, Debug, Default)]
Expand Down

0 comments on commit 494cc83

Please sign in to comment.