Skip to content

Commit

Permalink
fix: Repeated enum fields must be encoded as packed varints #27
Browse files Browse the repository at this point in the history
  • Loading branch information
rholshausen committed Jan 19, 2024
1 parent 9caef2a commit 884b36c
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# will have compiled files and executables
/target/
target/

# These are backup files generated by rustfmt
**/*.rs.bk
Expand Down
67 changes: 62 additions & 5 deletions src/message_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use prost_types::{DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, Fi
use prost_types::field_descriptor_proto::Type;
use tracing::{trace, warn};

use crate::utils::{last_name, should_be_packed_type};
use crate::utils::{last_name, should_be_packed_type, display_bytes};

/// Enum to set what type of field the value is for
#[derive(Clone, Copy, Debug, PartialEq)]
Expand Down Expand Up @@ -133,9 +133,9 @@ impl MessageBuilder {
}
}

trace!("encode_message: {} bytes", buffer.len());

Ok(buffer.freeze())
let bytes = buffer.freeze();
trace!("encode_message: {} bytes {}", bytes.len(), display_bytes(&bytes));
Ok(bytes)
}

fn encode_single_field(&self, mut buffer: &mut BytesMut, field_data: &FieldValueInner, value: Option<MessageFieldValue>) -> anyhow::Result<()> {
Expand Down Expand Up @@ -329,6 +329,7 @@ impl MessageBuilder {
buffer: &mut BytesMut,
field_value: &FieldValueInner
) -> anyhow::Result<()> {
trace!(">> encode_packed_field({:?})", field_value);
if let Some(tag) = field_value.descriptor.number {
match field_value.proto_type {
Type::Double => {
Expand Down Expand Up @@ -366,6 +367,16 @@ impl MessageBuilder {
prost::encoding::int32::encode_packed(tag as u32, &values, buffer);
Ok(())
}
Type::Enum => {
let values = field_value.values.iter()
.map(|v| match &v.rtype {
RType::Enum(i, _) => *i,
_ => v.rtype.as_i32().unwrap_or_default()
})
.collect::<Vec<i32>>();
prost::encoding::int32::encode_packed(tag as u32, &values, buffer);
Ok(())
}
Type::Fixed64 => {
let values = field_value.values.iter()
.map(|v| v.rtype.as_u64().unwrap_or_default())
Expand Down Expand Up @@ -708,7 +719,7 @@ impl MessageFieldValue {
}

#[cfg(test)]
mod tests {
pub(crate) mod tests {
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use bytes::{Bytes, BytesMut};
Expand Down Expand Up @@ -743,6 +754,7 @@ mod tests {
use crate::message_builder::MessageFieldValueType::Repeated;
use crate::message_decoder::{decode_message, ProtobufFieldData};
use crate::protobuf::tests::DESCRIPTOR_WITH_ENUM_BYTES;
use crate::utils::find_enum_by_name_in_message;

const ENCODED_MESSAGE: &str = "CuIFChxnb29nbGUvcHJvdG9idWYvc3RydWN0LnByb3RvEg9nb29nbGUucHJv\
dG9idWYimAEKBlN0cnVjdBI7CgZmaWVsZHMYASADKAsyIy5nb29nbGUucHJvdG9idWYuU3RydWN0LkZpZWxkc0VudHJ5\
Expand Down Expand Up @@ -1732,6 +1744,51 @@ mod tests {
expect!(result.to_vec()).to(be_equal_to(expected));
}

pub(crate) const REPEATED_ENUM_DESCRIPTORS: &str = "Cv4EChNyZXBlYXRlZF9lbnVtLnByb3RvEglwYWN0aXNzdWUieQoTQn\
Jva2VuU2FtcGxlUmVxdWVzdBI3CgR0eXBlGAEgAygOMiMucGFjdGlzc3VlLkJyb2tlblNhbXBsZVJlcXVlc3QuVHlwZVIEd\
HlwZSIpCgRUeXBlEgsKB1VOS05PV04QABIJCgVUWVBFMRABEgkKBVRZUEUyEAIiJgoUQnJva2VuU2FtcGxlUmVzcG9uc2US\
DgoCb2sYASABKAhSAm9rInsKFFdvcmtpbmdTYW1wbGVSZXF1ZXN0EjgKBHR5cGUYASABKA4yJC5wYWN0aXNzdWUuV29ya2l\
uZ1NhbXBsZVJlcXVlc3QuVHlwZVIEdHlwZSIpCgRUeXBlEgsKB1VOS05PV04QABIJCgVUWVBFMRABEgkKBVRZUEUyEAIiJw\
oVV29ya2luZ1NhbXBsZVJlc3BvbnNlEg4KAm9rGAEgASgIUgJvazJlChNCcm9rZW5TYW1wbGVTZXJ2aWNlEk4KCUdldFNhb\
XBsZRIeLnBhY3Rpc3N1ZS5Ccm9rZW5TYW1wbGVSZXF1ZXN0Gh8ucGFjdGlzc3VlLkJyb2tlblNhbXBsZVJlc3BvbnNlIgAya\
AoUV29ya2luZ1NhbXBsZVNlcnZpY2USUAoJR2V0U2FtcGxlEh8ucGFjdGlzc3VlLldvcmtpbmdTYW1wbGVSZXF1ZXN0GiAuc\
GFjdGlzc3VlLldvcmtpbmdTYW1wbGVSZXNwb25zZSIAQjpaOGdpdGh1Yi5jb20vc3Rhbi1pcy1oYXRlL3BhY3QtcHJvdG8ta\
XNzdWUtZGVtby87cGFjdGlzc3VlYgZwcm90bzM=";

#[test_log::test]
fn repeated_enum_fields_must_be_packed() {
let file_descriptor = get_file_descriptor("repeated_enum.proto", REPEATED_ENUM_DESCRIPTORS).unwrap();
let request_descriptor = file_descriptor.message_type.iter()
.find(|desc| desc.name.clone().unwrap_or_default() == "BrokenSampleRequest")
.unwrap();
let values_field_descriptor = request_descriptor.field.iter()
.find(|desc| desc.name.clone().unwrap_or_default() == "type")
.unwrap();
let mut builder = MessageBuilder::new(request_descriptor, "BrokenSampleRequest", &file_descriptor);
let enum_proto = find_enum_by_name_in_message(&request_descriptor.enum_type, "Type").unwrap();
let message_field_value = MessageFieldValue {
name: "type".to_string(),
raw_value: Some("Type2".to_string()),
rtype: RType::Enum(2, enum_proto.clone())
};
let message_field_value2 = MessageFieldValue {
name: "type".to_string(),
raw_value: Some("Type1".to_string()),
rtype: RType::Enum(1, enum_proto.clone())
};
builder.add_repeated_field_value(values_field_descriptor, "type", message_field_value);
builder.add_repeated_field_value(values_field_descriptor, "type", message_field_value2);

let expected = vec![
10, // Field 1, VARINT
2, // 2 bytes
2, // Enum 2 (Type2)
1 // Enum 1 (Type1)
];
let result = builder.encode_message().unwrap();
expect!(result.to_vec()).to(be_equal_to(expected));
}

#[test_log::test]
fn test_field_with_global_enum() {
let bytes: &[u8] = &DESCRIPTOR_WITH_ENUM_BYTES;
Expand Down
66 changes: 59 additions & 7 deletions src/message_decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,7 @@ pub fn decode_message<B>(
Type::Bool => vec![ (ProtobufFieldData::Boolean(varint > 0), wire_type) ],
Type::Uint32 => vec![ (ProtobufFieldData::UInteger32(varint as u32), wire_type) ],
Type::Enum => {
let enum_type_name = field_descriptor.type_name.clone().unwrap_or_default();
let enum_proto = find_enum_by_name_in_message(&descriptor.enum_type, enum_type_name.as_str())
.or_else(|| find_enum_by_name(descriptors, enum_type_name.as_str()))
.ok_or_else(|| anyhow!("Did not find the enum {} for the field {} in the Protobuf descriptor", enum_type_name, field_num))?;
vec![ (ProtobufFieldData::Enum(varint as i32, enum_proto.clone()), wire_type) ]
vec![ (decode_enum(descriptor, descriptors, &field_descriptor, varint)?, wire_type) ]
},
Type::Sint32 => {
let value = varint as u32;
Expand Down Expand Up @@ -333,7 +329,7 @@ pub fn decode_message<B>(
Type::Bytes => vec![ (ProtobufFieldData::Bytes(data_buffer.to_vec()), wire_type) ],
_ => if should_be_packed_type(t) && is_repeated_field(&field_descriptor) {
debug!("Reading length delimited field as a packed repeated field");
decode_packed_field(field_descriptor, &mut data_buffer)?
decode_packed_field(field_descriptor, descriptor, descriptors, &mut data_buffer)?
} else {
error!("Was expecting {:?} but received an unknown length-delimited type", t);
let mut buf = BytesMut::with_capacity((data_length + 8) as usize);
Expand Down Expand Up @@ -397,7 +393,25 @@ pub fn decode_message<B>(
Ok(fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)).cloned().collect())
}

fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow::Result<Vec<(ProtobufFieldData, WireType)>> {
fn decode_enum(
descriptor: &DescriptorProto,
descriptors: &FileDescriptorSet,
field_descriptor: &FieldDescriptorProto,
varint: u64
) -> anyhow::Result<ProtobufFieldData> {
let enum_type_name = field_descriptor.type_name.clone().unwrap_or_default();
let enum_proto = find_enum_by_name_in_message(&descriptor.enum_type, enum_type_name.as_str())
.or_else(|| find_enum_by_name(descriptors, enum_type_name.as_str()))
.ok_or_else(|| anyhow!("Did not find the enum {} for the field in the Protobuf descriptor", enum_type_name))?;
Ok(ProtobufFieldData::Enum(varint as i32, enum_proto.clone()))
}

fn decode_packed_field(
field: FieldDescriptorProto,
descriptor: &DescriptorProto,
descriptors: &FileDescriptorSet,
data: &mut Bytes
) -> anyhow::Result<Vec<(ProtobufFieldData, WireType)>> {
let mut values = vec![];
let t: Type = field.r#type();
match t {
Expand Down Expand Up @@ -429,6 +443,13 @@ fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow:
values.push((ProtobufFieldData::Integer32(varint as i32), WireType::Varint));
}
}
Type::Enum => {
while data.remaining() > 0 {
let varint = decode_varint(data)?;
let enum_value = decode_enum(descriptor, descriptors, &field, varint)?;
values.push((enum_value, WireType::Varint));
}
}
Type::Fixed64 => {
while data.remaining() >= mem::size_of::<u64>() {
values.push((ProtobufFieldData::UInteger64(data.get_u64_le()), WireType::SixtyFourBit));
Expand Down Expand Up @@ -492,6 +513,8 @@ fn find_field_descriptor(field_num: i32, descriptor: &DescriptorProto) -> anyhow

#[cfg(test)]
mod tests {
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use bytes::{BufMut, Bytes, BytesMut};
use expectest::prelude::*;
use pact_plugin_driver::proto::InitPluginRequest;
Expand All @@ -514,6 +537,7 @@ mod tests {
};
use crate::message_decoder::{decode_message, ProtobufFieldData};
use crate::protobuf::tests::DESCRIPTOR_WITH_ENUM_BYTES;
use crate::message_builder::tests::REPEATED_ENUM_DESCRIPTORS;

const FIELD_1_MESSAGE: [u8; 2] = [8, 1];
const FIELD_2_MESSAGE: [u8; 2] = [16, 55];
Expand Down Expand Up @@ -1207,4 +1231,32 @@ mod tests {
expect!(field_result.wire_type).to(be_equal_to(WireType::Varint));
expect!(&field_result.data).to(be_equal_to(&ProtobufFieldData::Enum(1, enum_proto.clone())));
}

#[test_log::test]
fn decode_message_with_repeated_enum_field() {
let bytes = BASE64.decode(REPEATED_ENUM_DESCRIPTORS).unwrap();
let buffer = Bytes::from(bytes);
let fds: FileDescriptorSet = FileDescriptorSet::decode(buffer).unwrap();
let main_descriptor = fds.file.iter()
.find(|fd| fd.name.clone().unwrap_or_default() == "repeated_enum.proto")
.unwrap();
let message_descriptor = main_descriptor.message_type.iter()
.find(|md| md.name.clone().unwrap_or_default() == "BrokenSampleRequest").unwrap();
let enum_proto = message_descriptor.enum_type.first().unwrap();

let message_bytes: &[u8] = &[10, 3, 2, 0, 1];
let mut buffer = Bytes::from(message_bytes);
let result = decode_message(&mut buffer, &message_descriptor, &fds).unwrap();
expect!(result.len()).to(be_equal_to(3));

expect!(result[0].field_num).to(be_equal_to(1));
expect!(result[0].wire_type).to(be_equal_to(WireType::Varint));
expect!(&result[0].data).to(be_equal_to(&ProtobufFieldData::Enum(2, enum_proto.clone())));
expect!(result[1].field_num).to(be_equal_to(1));
expect!(result[1].wire_type).to(be_equal_to(WireType::Varint));
expect!(&result[1].data).to(be_equal_to(&ProtobufFieldData::Enum(0, enum_proto.clone())));
expect!(result[2].field_num).to(be_equal_to(1));
expect!(result[2].wire_type).to(be_equal_to(WireType::Varint));
expect!(&result[2].data).to(be_equal_to(&ProtobufFieldData::Enum(1, enum_proto.clone())));
}
}
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ pub(crate) fn find_service_descriptor<'a>(
pub fn should_be_packed_type(field_type: Type) -> bool {
matches!(field_type, Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 |
Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 |
Type::Sint64)
Type::Sint64 | Type::Enum)
}

/// Tries to convert a Protobuf Value to a Map. Returns an error if the incoming value is not a
Expand Down

0 comments on commit 884b36c

Please sign in to comment.