Skip to content

Commit

Permalink
fix: add panic catch on operator calling FFI (#1196)
Browse files Browse the repository at this point in the history
  • Loading branch information
entropidelic authored Oct 18, 2024
2 parents e06df5c + 20d03fb commit 84fc022
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 42 deletions.
3 changes: 2 additions & 1 deletion operator/merkle_tree/lib/merkle_tree.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);
int32_t verify_merkle_tree_batch_ffi(unsigned char *batch_bytes, unsigned int batch_len, unsigned char *merkle_root);
25 changes: 20 additions & 5 deletions operator/merkle_tree/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use aligned_sdk::core::types::{
use lambdaworks_crypto::merkle_tree::merkle::MerkleTree;
use log::error;

#[no_mangle]
pub extern "C" fn verify_merkle_tree_batch_ffi(
fn inner_verify_merkle_tree_batch_ffi(
batch_ptr: *const u8,
batch_len: usize,
merkle_root: &[u8; 32],
Expand Down Expand Up @@ -53,6 +52,22 @@ pub extern "C" fn verify_merkle_tree_batch_ffi(
computed_batch_merkle_tree.root == *merkle_root
}

#[no_mangle]
pub extern "C" fn verify_merkle_tree_batch_ffi(
batch_ptr: *const u8,
batch_len: usize,
merkle_root: &[u8; 32],
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_merkle_tree_batch_ffi(batch_ptr, batch_len, merkle_root)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -75,7 +90,7 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, true);
assert_eq!(result, 1);
}

#[test]
Expand All @@ -92,7 +107,7 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, false);
assert_eq!(result, 0);
}

#[test]
Expand All @@ -109,6 +124,6 @@ mod tests {
let result =
verify_merkle_tree_batch_ffi(bytes_vec.as_ptr(), bytes_vec.len(), &merkle_root);

assert_eq!(result, false);
assert_eq!(result, 0);
}
}
28 changes: 25 additions & 3 deletions operator/merkle_tree/merkle_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,35 @@ package merkle_tree
*/
import "C"
import "unsafe"
import "fmt"

func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) bool {
func VerifyMerkleTreeBatch(batchBuffer []byte, merkleRootBuffer [32]byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil
if len(batchBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying merkle tree batch: %s", rec)
}
}()

batchPtr := (*C.uchar)(unsafe.Pointer(&batchBuffer[0]))
merkleRootPtr := (*C.uchar)(unsafe.Pointer(&merkleRootBuffer[0]))
return (bool)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))

r := (C.int32_t)(C.verify_merkle_tree_batch_ffi(batchPtr, (C.uint)(len(batchBuffer)), merkleRootPtr))

if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying merkle tree batch")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
3 changes: 2 additions & 1 deletion operator/merkle_tree/merkle_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func TestVerifyMerkleTreeBatch(t *testing.T) {
var merkleRoot [32]byte
copy(merkleRoot[:], merkle_root)

if !VerifyMerkleTreeBatch(batchByteValue, merkleRoot) {
verified, err := VerifyMerkleTreeBatch(batchByteValue, merkleRoot)
if err != nil || !verified {
t.Errorf("Batch did not verify Merkle Root")
}

Expand Down
18 changes: 14 additions & 4 deletions operator/pkg/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,13 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
results <- verificationResult

case common.SP1:
verificationResult := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
o.Logger.Infof("SP1 proof verification result: %t", verificationResult)
results <- verificationResult
verificationResult, err := sp1.VerifySp1Proof(verificationData.Proof, verificationData.VmProgramCode)
o.handleVerificationResult(results, verificationResult, err, "SP1 proof verification")

case common.Risc0:
verificationResult := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
verificationResult, err := risc_zero.VerifyRiscZeroReceipt(verificationData.Proof,
verificationData.VmProgramCode, verificationData.PubInput)
o.handleVerificationResult(results, verificationResult, err, "RiscZero proof verification")

o.Logger.Infof("Risc0 proof verification result: %t", verificationResult)
results <- verificationResult
Expand All @@ -512,6 +512,16 @@ func (o *Operator) verify(verificationData VerificationData, results chan bool)
}
}

func (o *Operator) handleVerificationResult(results chan bool, isVerified bool, err error, name string) {
if err != nil {
o.Logger.Errorf("%v failed %v", name, err)
results <- false
} else {
o.Logger.Infof("%v result: %t", name, isVerified)
results <- isVerified
}
}

// VerifyPlonkProofBLS12_381 verifies a PLONK proof using BLS12-381 curve.
func (o *Operator) verifyPlonkProofBLS12_381(proofBytes []byte, pubInputBytes []byte, verificationKeyBytes []byte) bool {
return o.verifyPlonkProof(proofBytes, pubInputBytes, verificationKeyBytes, ecc.BLS12_381)
Expand Down
6 changes: 3 additions & 3 deletions operator/pkg/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ func (o *Operator) getBatchFromDataService(ctx context.Context, batchURL string,

// Checks if downloaded merkle root is the same as the expected one
o.Logger.Infof("Verifying batch merkle tree...")
merkle_root_check := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
if !merkle_root_check {
return nil, fmt.Errorf("merkle root check failed")
merkle_root_check, err := merkle_tree.VerifyMerkleTreeBatch(batchBytes, expectedMerkleRoot)
if err != nil || !merkle_root_check {
return nil, fmt.Errorf("Error while verifying merkle tree batch")
}
o.Logger.Infof("Batch merkle tree verified")

Expand Down
2 changes: 1 addition & 1 deletion operator/risc_zero/lib/risc_zero.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);
int32_t verify_risc_zero_receipt_ffi(unsigned char *inner_receipt_bytes, uint32_t inner_receipt_len, unsigned char *image_id, uint32_t image_id_len, unsigned char *public_input, uint32_t public_input_len);
35 changes: 30 additions & 5 deletions operator/risc_zero/lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use log::error;
use risc0_zkvm::{InnerReceipt, Receipt};

#[no_mangle]
pub extern "C" fn verify_risc_zero_receipt_ffi(
fn inner_verify_risc_zero_receipt_ffi(
inner_receipt_bytes: *const u8,
inner_receipt_len: u32,
image_id: *const u8,
Expand Down Expand Up @@ -43,6 +42,32 @@ pub extern "C" fn verify_risc_zero_receipt_ffi(
false
}

#[no_mangle]
pub extern "C" fn verify_risc_zero_receipt_ffi(
inner_receipt_bytes: *const u8,
inner_receipt_len: u32,
image_id: *const u8,
image_id_len: u32,
public_input: *const u8,
public_input_len: u32,
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_risc_zero_receipt_ffi(
inner_receipt_bytes,
inner_receipt_len,
image_id,
image_id_len,
public_input,
public_input_len,
)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -69,7 +94,7 @@ mod tests {
public_input,
PUBLIC_INPUT.len() as u32,
);
assert!(result)
assert_eq!(result, 1)
}

#[test]
Expand All @@ -86,7 +111,7 @@ mod tests {
public_input,
PUBLIC_INPUT.len() as u32,
);
assert!(!result)
assert_eq!(result, 0)
}

#[test]
Expand All @@ -103,6 +128,6 @@ mod tests {
public_input,
0,
);
assert!(!result)
assert_eq!(result, 0)
}
}
38 changes: 30 additions & 8 deletions operator/risc_zero/risc_zero.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,44 @@ package risc_zero
#include "lib/risc_zero.h"
*/
import "C"
import (
"unsafe"
)
import "unsafe"
import "fmt"

func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil

func VerifyRiscZeroReceipt(innerReceiptBuffer []byte, imageIdBuffer []byte, publicInputBuffer []byte) bool {
if len(innerReceiptBuffer) == 0 || len(imageIdBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying risc0 proof: %s", rec)
}
}()

receiptPtr := (*C.uchar)(unsafe.Pointer(&innerReceiptBuffer[0]))
imageIdPtr := (*C.uchar)(unsafe.Pointer(&imageIdBuffer[0]))

r := (C.int32_t)(0)

if len(publicInputBuffer) == 0 { // allow empty public input
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), nil, (C.uint32_t)(0)))
} else {
publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
r = (C.int32_t)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
}

publicInputPtr := (*C.uchar)(unsafe.Pointer(&publicInputBuffer[0]))
return (bool)(C.verify_risc_zero_receipt_ffi(receiptPtr, (C.uint32_t)(len(innerReceiptBuffer)), imageIdPtr, (C.uint32_t)(len(imageIdBuffer)), publicInputPtr, (C.uint32_t)(len(publicInputBuffer))))
if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying risc0 proof")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
4 changes: 2 additions & 2 deletions operator/risc_zero/risc_zero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestFibonacciRiscZeroProofVerifies(t *testing.T) {
if err != nil {
t.Errorf("could not open public input file: %s", err)
}

if !risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes) {
verified, err := risc_zero.VerifyRiscZeroReceipt(innerReceiptBytes, imageIdBytes, publicInputBytes)
if err != nil || !verified {
t.Errorf("proof did not verify")
}
}
2 changes: 1 addition & 1 deletion operator/sp1/lib/sp1.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <stdbool.h>
#include <stdint.h>

bool verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len,
int32_t verify_sp1_proof_ffi(unsigned char *proof_buffer, uint32_t proof_len,
unsigned char *elf_buffer, uint32_t elf_len);
24 changes: 20 additions & 4 deletions operator/sp1/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ lazy_static! {
static ref PROVER_CLIENT: ProverClient = ProverClient::new();
}

#[no_mangle]
pub extern "C" fn verify_sp1_proof_ffi(
fn inner_verify_sp1_proof_ffi(
proof_bytes: *const u8,
proof_len: u32,
elf_bytes: *const u8,
Expand Down Expand Up @@ -35,6 +34,23 @@ pub extern "C" fn verify_sp1_proof_ffi(
false
}

#[no_mangle]
pub extern "C" fn verify_sp1_proof_ffi(
proof_bytes: *const u8,
proof_len: u32,
elf_bytes: *const u8,
elf_len: u32,
) -> i32 {
let result = std::panic::catch_unwind(|| {
inner_verify_sp1_proof_ffi(proof_bytes, proof_len, elf_bytes, elf_len)
});

match result {
Ok(v) => v as i32,
Err(_) => -1,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -49,7 +65,7 @@ mod tests {

let result =
verify_sp1_proof_ffi(proof_bytes, PROOF.len() as u32, elf_bytes, ELF.len() as u32);
assert!(result)
assert_eq!(result, 1)
}

#[test]
Expand All @@ -63,6 +79,6 @@ mod tests {
elf_bytes,
ELF.len() as u32,
);
assert!(!result)
assert_eq!(result, 0)
}
}
27 changes: 24 additions & 3 deletions operator/sp1/sp1.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,35 @@ package sp1
*/
import "C"
import "unsafe"
import "fmt"

func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) bool {
func VerifySp1Proof(proofBuffer []byte, elfBuffer []byte) (isVerified bool, err error) {
// Here we define the return value on failure
isVerified = false
err = nil
if len(proofBuffer) == 0 || len(elfBuffer) == 0 {
return false
return isVerified, err
}

// This will catch any go panic
defer func() {
rec := recover()
if rec != nil {
err = fmt.Errorf("Panic was caught while verifying sp1 proof: %s", rec)
}
}()

proofPtr := (*C.uchar)(unsafe.Pointer(&proofBuffer[0]))
elfPtr := (*C.uchar)(unsafe.Pointer(&elfBuffer[0]))

return (bool)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer))))
r := (C.int32_t)(C.verify_sp1_proof_ffi(proofPtr, (C.uint32_t)(len(proofBuffer)), elfPtr, (C.uint32_t)(len(elfBuffer))))

if r == -1 {
err = fmt.Errorf("Panic happened on FFI while verifying sp1 proof")
return isVerified, err
}

isVerified = (r == 1)

return isVerified, err
}
Loading

0 comments on commit 84fc022

Please sign in to comment.