diff --git a/operator/merkle_tree/lib/merkle_tree.h b/operator/merkle_tree/lib/merkle_tree.h index 46bd2b5fd..39ddf474e 100644 --- a/operator/merkle_tree/lib/merkle_tree.h +++ b/operator/merkle_tree/lib/merkle_tree.h @@ -1,3 +1,4 @@ #include +#include -bool verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root); +int32_t verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root); diff --git a/operator/merkle_tree/lib/src/lib.rs b/operator/merkle_tree/lib/src/lib.rs index 415b6363c..15f930899 100644 --- a/operator/merkle_tree/lib/src/lib.rs +++ b/operator/merkle_tree/lib/src/lib.rs @@ -4,8 +4,7 @@ use aligned_sdk::core::types::{ use lambdaworks_crypto::merkle_tree::merkle::MerkleTree; use log::error; -#[no_mangle] -pub extern "C" fn verify_merkle_tree_batch_ffi( +fn inner_verify_merkle_tree_batch_ffi( batch_ptr: *const u8, batch_len: usize, merkle_root: &[u8; 32], @@ -53,6 +52,22 @@ pub extern "C" fn verify_merkle_tree_batch_ffi( computed_batch_merkle_tree.root == *merkle_root } +#[no_mangle] +pub extern "C" fn verify_merkle_tree_batch_ffi( + batch_ptr: *const u8, + batch_len: usize, + merkle_root: &[u8; 32], +) -> i32 { + let result = std::panic::catch_unwind(|| { + inner_verify_merkle_tree_batch_ffi(batch_ptr, batch_len, merkle_root) + }); + + match result { + Ok(v) => v as i32, + Err(_) => -1, + } +} + #[cfg(test)] mod tests { use super::*; @@ -75,7 +90,7 @@ mod tests { let result = verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root); - assert_eq!(result, true); + assert_eq!(result, 1); } #[test] @@ -92,7 +107,7 @@ mod tests { let result = verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root); - assert_eq!(result, false); + assert_eq!(result, 0); } #[test] @@ -109,6 +124,6 @@ mod tests { let result = verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root); - assert_eq!(result, false); + assert_eq!(result, 0); } } diff --git a/operator/merkle_tree/merkle_tree.go b/operator/merkle_tree/merkle_tree.go index 9b6562481..781b933a6 100644 --- a/operator/merkle_tree/merkle_tree.go +++ b/operator/merkle_tree/merkle_tree.go @@ -8,13 +8,35 @@ package merkle_tree */ import "C" import "unsafe" +import "fmt" -func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) bool { +func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) (isVerified bool, err error) { + // Here we define the return value on failure + isVerified = false + err = nil if len(batchBuffer) == 0 { - return false + return isVerified, err } + // This will catch any go panic + defer func() { + rec := recover() + if rec != nil { + err = fmt.Errorf("Panic was caught while verifying merkle tree batch: %s", rec) + } + }() + batchPtr := (*C.uchar)(unsafe.Pointer(&batchBuffer[0])) merkleRootPtr := (*C.uchar)(unsafe.Pointer(&merkleRootBuffer[0])) - return (bool)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr)) + + r := (C.int32_t)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr)) + + if r == -1 { + err = fmt.Errorf("Panic happened on FFI while verifying merkle tree batch") + return isVerified, err + } + + isVerified = (r == 1) + + return isVerified, err } diff --git a/operator/merkle_tree/merkle_tree_test.go b/operator/merkle_tree/merkle_tree_test.go index 539391929..e6a1ef557 100644 --- a/operator/merkle_tree/merkle_tree_test.go +++ b/operator/merkle_tree/merkle_tree_test.go @@ -32,7 +32,8 @@ func TestVerifyMerkleTreeBatch(t *testing.T) { var merkleRoot [32]byte copy(merkleRoot[:], merkle_root) - if !VerifyMerkleTreeBatch(batchByteValue, merkleRoot) { + verified, err := VerifyMerkleTreeBatch(batchByteValue, merkleRoot) + if err != nil || !verified { t.Errorf("Batch did not verify Merkle Root") } diff --git a/operator/pkg/operator.go b/operator/pkg/operator.go index 8a96b51c4..7dff12bea 100644 --- a/operator/pkg/operator.go +++ b/operator/pkg/operator.go @@ -496,13 +496,13 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool) results <- verificationResult case common.SP1: - verificationResult := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode) - o.Logger.Infof("SP1 proof verification result: %t", verificationResult) - results <- verificationResult + verificationResult, err := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode) + o.handleVerificationResult(results, verificationResult, err, "SP1 proof verification") case common.Risc0: - verificationResult := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof, + verificationResult, err := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof, verificationData.VmProgramCode, verificationData.PubInput) + o.handleVerificationResult(results, verificationResult, err, "RiscZero proof verification") o.Logger.Infof("Risc0 proof verification result: %t", verificationResult) results <- verificationResult @@ -512,6 +512,16 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool) } } +func (o *Operator) handleVerificationResult(results chan bool, isVerified bool, err error, name string) { + if err != nil { + o.Logger.Errorf("%v failed %v", name, err) + results <- false + } else { + o.Logger.Infof("%v result: %t", name, isVerified) + results <- isVerified + } +} + // VerifyPlonkProofBLS12_381 verifies a PLONK proof using BLS12-381 curve. func (o *Operator) verifyPlonkProofBLS12_381(proofBytes []byte, pubInputBytes []byte, verificationKeyBytes []byte) bool { return o.verifyPlonkProof(proofBytes, pubInputBytes, verificationKeyBytes, ecc.BLS12_381) diff --git a/operator/pkg/s3.go b/operator/pkg/s3.go index dafffa418..865115fa7 100644 --- a/operator/pkg/s3.go +++ b/operator/pkg/s3.go @@ -90,9 +90,9 @@ func (o *Operator) getBatchFromDataService(ctx context.Context, batchURL string, // Checks if downloaded merkle root is the same as the expected one o.Logger.Infof("Verifying batch merkle tree...") - merkle_root_check := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot) - if !merkle_root_check { - return nil, fmt.Errorf("merkle root check failed") + merkle_root_check, err := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot) + if err != nil || !merkle_root_check { + return nil, fmt.Errorf("Error while verifying merkle tree batch") } o.Logger.Infof("Batch merkle tree verified") diff --git a/operator/risc_zero/lib/risc_zero.h b/operator/risc_zero/lib/risc_zero.h index 1c00547cf..982e68c8c 100644 --- a/operator/risc_zero/lib/risc_zero.h +++ b/operator/risc_zero/lib/risc_zero.h @@ -1,4 +1,4 @@ #include #include -bool verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len); +int32_t verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len); diff --git a/operator/risc_zero/lib/src/lib.rs b/operator/risc_zero/lib/src/lib.rs index 13be1cafc..d23869cd9 100644 --- a/operator/risc_zero/lib/src/lib.rs +++ b/operator/risc_zero/lib/src/lib.rs @@ -1,8 +1,7 @@ use log::error; use risc0_zkvm::{InnerReceipt, Receipt}; -#[no_mangle] -pub extern "C" fn verify_risc_zero_receipt_ffi( +fn inner_verify_risc_zero_receipt_ffi( inner_receipt_bytes: *const u8, inner_receipt_len: u32, image_id: *const u8, @@ -43,6 +42,32 @@ pub extern "C" fn verify_risc_zero_receipt_ffi( false } +#[no_mangle] +pub extern "C" fn verify_risc_zero_receipt_ffi( + inner_receipt_bytes: *const u8, + inner_receipt_len: u32, + image_id: *const u8, + image_id_len: u32, + public_input: *const u8, + public_input_len: u32, +) -> i32 { + let result = std::panic::catch_unwind(|| { + inner_verify_risc_zero_receipt_ffi( + inner_receipt_bytes, + inner_receipt_len, + image_id, + image_id_len, + public_input, + public_input_len, + ) + }); + + match result { + Ok(v) => v as i32, + Err(_) => -1, + } +} + #[cfg(test)] mod tests { use super::*; @@ -69,7 +94,7 @@ mod tests { public_input, PUBLIC_INPUT.len() as u32, ); - assert!(result) + assert_eq!(result, 1) } #[test] @@ -86,7 +111,7 @@ mod tests { public_input, PUBLIC_INPUT.len() as u32, ); - assert!(!result) + assert_eq!(result, 0) } #[test] @@ -103,6 +128,6 @@ mod tests { public_input, 0, ); - assert!(!result) + assert_eq!(result, 0) } } diff --git a/operator/risc_zero/risc_zero.go b/operator/risc_zero/risc_zero.go index 92716ae02..3d1d08b39 100644 --- a/operator/risc_zero/risc_zero.go +++ b/operator/risc_zero/risc_zero.go @@ -7,22 +7,44 @@ package risc_zero #include "lib/risc_zero.h" */ import "C" -import ( - "unsafe" -) +import "unsafe" +import "fmt" + +func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) (isVerified bool, err error) { + // Here we define the return value on failure + isVerified = false + err = nil -func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) bool { if len(innerReceiptBuffer) == 0 || len(imageIdBuffer) == 0 { - return false + return isVerified, err } + // This will catch any go panic + defer func() { + rec := recover() + if rec != nil { + err = fmt.Errorf("Panic was caught while verifying risc0 proof: %s", rec) + } + }() + receiptPtr := (*C.uchar)(unsafe.Pointer(&innerReceiptBuffer[0])) imageIdPtr := (*C.uchar)(unsafe.Pointer(&imageIdBuffer[0])) + r := (C.int32_t)(0) + if len(publicInputBuffer) == 0 { // allow empty public input - return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0))) + r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0))) + } else { + publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0])) + r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer)))) } - publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0])) - return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer)))) + if r == -1 { + err = fmt.Errorf("Panic happened on FFI while verifying risc0 proof") + return isVerified, err + } + + isVerified = (r == 1) + + return isVerified, err } diff --git a/operator/risc_zero/risc_zero_test.go b/operator/risc_zero/risc_zero_test.go index cb83a06f5..d23a30c5b 100644 --- a/operator/risc_zero/risc_zero_test.go +++ b/operator/risc_zero/risc_zero_test.go @@ -22,8 +22,8 @@ func TestFibonacciRiscZeroProofVerifies(t *testing.T) { if err != nil { t.Errorf("could not open public input file: %s", err) } - - if !risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes) { + verified, err := risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes) + if err != nil || !verified { t.Errorf("proof did not verify") } } diff --git a/operator/sp1/lib/sp1.h b/operator/sp1/lib/sp1.h index 4a14ad959..aeb607b3a 100644 --- a/operator/sp1/lib/sp1.h +++ b/operator/sp1/lib/sp1.h @@ -1,5 +1,5 @@ #include #include -bool verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len, +int32_t verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len, unsigned char *elf_buffer, uint32_t elf_len); diff --git a/operator/sp1/lib/src/lib.rs b/operator/sp1/lib/src/lib.rs index c3145a00c..266d9e0be 100644 --- a/operator/sp1/lib/src/lib.rs +++ b/operator/sp1/lib/src/lib.rs @@ -6,8 +6,7 @@ lazy_static! { static ref PROVER_CLIENT: ProverClient = ProverClient::new(); } -#[no_mangle] -pub extern "C" fn verify_sp1_proof_ffi( +fn inner_verify_sp1_proof_ffi( proof_bytes: *const u8, proof_len: u32, elf_bytes: *const u8, @@ -35,6 +34,23 @@ pub extern "C" fn verify_sp1_proof_ffi( false } +#[no_mangle] +pub extern "C" fn verify_sp1_proof_ffi( + proof_bytes: *const u8, + proof_len: u32, + elf_bytes: *const u8, + elf_len: u32, +) -> i32 { + let result = std::panic::catch_unwind(|| { + inner_verify_sp1_proof_ffi(proof_bytes, proof_len, elf_bytes, elf_len) + }); + + match result { + Ok(v) => v as i32, + Err(_) => -1, + } +} + #[cfg(test)] mod tests { use super::*; @@ -49,7 +65,7 @@ mod tests { let result = verify_sp1_proof_ffi(proof_bytes, PROOF.len() as u32, elf_bytes, ELF.len() as u32); - assert!(result) + assert_eq!(result, 1) } #[test] @@ -63,6 +79,6 @@ mod tests { elf_bytes, ELF.len() as u32, ); - assert!(!result) + assert_eq!(result, 0) } } diff --git a/operator/sp1/sp1.go b/operator/sp1/sp1.go index 64b310844..121c1beb0 100644 --- a/operator/sp1/sp1.go +++ b/operator/sp1/sp1.go @@ -8,14 +8,35 @@ package sp1 */ import "C" import "unsafe" +import "fmt" -func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) bool { +func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) (isVerified bool, err error) { + // Here we define the return value on failure + isVerified = false + err = nil if len(proofBuffer) == 0 || len(elfBuffer) == 0 { - return false + return isVerified, err } + // This will catch any go panic + defer func() { + rec := recover() + if rec != nil { + err = fmt.Errorf("Panic was caught while verifying sp1 proof: %s", rec) + } + }() + proofPtr := (*C.uchar)(unsafe.Pointer(&proofBuffer[0])) elfPtr := (*C.uchar)(unsafe.Pointer(&elfBuffer[0])) - return (bool)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer)))) + r := (C.int32_t)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer)))) + + if r == -1 { + err = fmt.Errorf("Panic happened on FFI while verifying sp1 proof") + return isVerified, err + } + + isVerified = (r == 1) + + return isVerified, err } diff --git a/operator/sp1/sp1_test.go b/operator/sp1/sp1_test.go index d342751bd..a52425269 100644 --- a/operator/sp1/sp1_test.go +++ b/operator/sp1/sp1_test.go @@ -22,7 +22,8 @@ func TestFibonacciSp1ProofVerifies(t *testing.T) { t.Errorf("could not open elf file: %s", err) } - if !sp1.VerifySp1Proof(proofBytes, elfBytes) { + verified, err := sp1.VerifySp1Proof(proofBytes, elfBytes) + if err != nil || !verified { t.Errorf("proof did not verify") } }