diff --git a/starknet-core/src/crypto.rs b/starknet-core/src/crypto.rs index e703fc24..3b891ca9 100644 --- a/starknet-core/src/crypto.rs +++ b/starknet-core/src/crypto.rs @@ -98,6 +98,31 @@ pub fn ecdsa_verify( } } +/// MaskBits masks the specified (excess) bits in a byte slice. +/// +/// Parameters: +/// - mask: is an integer representing the number of bits to mask +/// - word_size: is an integer representing the number of bits in each element of the slice +/// - slice: is a byte slice on which the masking operation is performed +/// Makes the operation in place on slice, gets a new byte slice that contains the masked bits +/// Porting the MaskBits from https://github.com/NethermindEth/starknet.go/blob/main/curve/utils.go#L128-L143 +pub fn mask_bits(mask: usize, word_size: usize, slice: &mut [u8]) { + let mut excess = slice.len() * word_size - mask; + + for v in slice { + let by = v; + if excess > 0 { + if excess > word_size { + excess -= word_size; + continue; + } + *by <<= excess; + *by >>= excess; + excess = 0; + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -226,4 +251,23 @@ mod tests { assert!(!ecdsa_verify(&public_key, &message_hash, &Signature { r, s }).unwrap()); } + + #[test] + fn test_mask_bits() { + // Given + let mut to_mask: [u8; 32] = [ + 72, 232, 75, 188, 182, 142, 15, 124, 115, 169, 5, 139, 168, 43, 109, 169, 193, 255, + 220, 80, 46, 252, 240, 52, 231, 139, 12, 0, 60, 34, 236, 201, + ]; + + // When + mask_bits(250, 8, &mut to_mask[..]); + let expected: [u8; 32] = [ + 0, 232, 75, 188, 182, 142, 15, 124, 115, 169, 5, 139, 168, 43, 109, 169, 193, 255, 220, + 80, 46, 252, 240, 52, 231, 139, 12, 0, 60, 34, 236, 201, + ]; + + // Then + assert_eq!(to_mask, expected); + } }