Skip to content

Commit

Permalink
Search descritors with no package when no package was specified.
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-is-hate committed Jun 5, 2024
1 parent ce61d8f commit a8c9f46
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 60 deletions.
6 changes: 3 additions & 3 deletions src/mock_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,21 @@ impl Service<Request<hyper::Body>> for GrpcMockServer {
if let Some((service, method)) = request_path[1..].split_once('/') {
let service_name = last_name(service);
let lookup = format!("{service_name}/{method}");
if let Some((file, file_descriptor, method_descriptor, message)) = routes.get(lookup.as_str()) {
if let Some((file, _file_descriptor, method_descriptor, message)) = routes.get(lookup.as_str()) {
trace!(message = message.description.as_str(), "Found route for service call");
let file_descriptors: HashMap<String, &FileDescriptorProto> = file.file.iter().map(
|des| (des.name.clone().unwrap_or_default(), des)).collect();
let input_name = method_descriptor.input_type.as_ref().expect(format!(
"Input message name is empty for service {}/{}", service_name, method).as_str());
let (input_message_name, input_package_name) = split_name(input_name);
let input_message = find_message_descriptor(
input_message_name, input_package_name, file_descriptor, &file_descriptors);
input_message_name, input_package_name, &file_descriptors);

let output_name = method_descriptor.output_type.as_ref().expect(format!(
"Output message name is empty for service {}/{}", service_name, method).as_str());
let (output_message_name, output_package_name) = split_name(output_name);
let output_message = find_message_descriptor(
output_message_name, output_package_name, file_descriptor, &file_descriptors);
output_message_name, output_package_name, &file_descriptors);

if let Ok(input_message) = input_message {
if let Ok(output_message) = output_message {
Expand Down
15 changes: 7 additions & 8 deletions src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ fn construct_protobuf_interaction_for_service(
trace!(%input_name, ?input_package, input_message_name, "Input message");
trace!(%output_name, ?output_package, output_message_name, "Output message");

let request_descriptor = find_message_descriptor(input_message_name, input_package, file_descriptor, all_descriptors)?;
let response_descriptor = find_message_descriptor(output_message_name, output_package, file_descriptor, all_descriptors)?;
let request_descriptor = find_message_descriptor(input_message_name, input_package, all_descriptors)?;
let response_descriptor = find_message_descriptor(output_message_name, output_package, all_descriptors)?;

trace!("request_descriptor = {:?}", request_descriptor);
trace!("response_descriptor = {:?}", response_descriptor);
Expand Down Expand Up @@ -636,7 +636,7 @@ fn build_single_embedded_field_value(
debug!("Configuring the message from config {:?}", config);
let (message_name, package_name) = split_name(type_name.as_str());
let embedded_type = find_nested_type(&message_builder.descriptor, field_descriptor)
.or_else(|| find_message_descriptor(message_name, package_name, &message_builder.file_descriptor, all_descriptors).ok())
.or_else(|| find_message_descriptor(message_name, package_name, all_descriptors).ok())
.ok_or_else(|| anyhow!("Did not find message '{}' in the current message or in the file descriptors", type_name))?;
let mut embedded_builder = MessageBuilder::new(&embedded_type, message_name, &message_builder.file_descriptor);

Expand Down Expand Up @@ -2734,15 +2734,14 @@ pub(crate) mod tests {
let all: HashMap<String, &FileDescriptorProto> = fds.file
.iter().map(|des| (des.name.clone().unwrap_or_default(), des))
.collect();
let file_descriptor = &fds.file[0];

let result = super::find_message_descriptor("RectangleLocationRequest", None, file_descriptor, &all).unwrap();
let result = super::find_message_descriptor("RectangleLocationRequest", None, &all).unwrap();
expect!(result.field.len()).to(be_equal_to(2));
let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), file_descriptor, &all).unwrap();
let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), &all).unwrap();
expect!(result.field.len()).to(be_equal_to(4));
let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), file_descriptor, &all).unwrap();
let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), &all).unwrap();
expect!(result.field.len()).to(be_equal_to(4));
let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), file_descriptor, &all).unwrap();
let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), &all).unwrap();
expect!(result.field.len()).to(be_equal_to(2));
}
}
131 changes: 82 additions & 49 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,31 +82,32 @@ pub fn find_message_type_in_file_descriptor(
message_name, descriptor.name.as_deref().unwrap_or("unknown")))
}

/// Very similar to find_message_type_by_name, but it narrows down a list of file descriptors to the ones where package
/// name matches the given package name, either the one provided explicitly, or the one from the file_descriptor.
/// Finds message descriptor in the map of all file descriptors. If the package is provided, it will
/// search only the descriptors matching the package. If not, it will search all descriptors with no package specified.
/// (because package is an optional field in proto3)
pub fn find_message_descriptor(
message_name: &str,
package: Option<&str>,
file_descriptor: &FileDescriptorProto,
all_descriptors: &HashMap<String, &FileDescriptorProto>,
) -> anyhow::Result<DescriptorProto> {
let package = package
.filter(|p| !p.is_empty())
.unwrap_or(file_descriptor.package.as_deref().unwrap_or_default());
debug!(
"Looking for message '{}' in package '{}'",
message_name, package
);
find_all_file_descriptors_for_package(package, all_descriptors)?
.iter()
.find_map(|fd| find_message_type_in_file_descriptor(message_name, fd).ok())
.ok_or_else(|| {
anyhow!(
"Did not find a message type '{}' in any of the file descriptors for package '{}'",
message_name,
package,
)
})
let descriptors;
if let Some(package) = package {
debug!(
"Looking for message '{}' in package '{}'",
message_name, package
);
descriptors = find_all_file_descriptors_for_package(package, all_descriptors)?;
} else {
descriptors = find_all_file_descriptors_with_no_package(all_descriptors)?;
}
descriptors.iter()
.find_map(|fd| find_message_type_in_file_descriptor(message_name, fd).ok())
.ok_or_else(|| {
anyhow!(
"Did not find a message type '{}' in any of the file descriptors '{:?}'",
message_name,
descriptors.iter().map(|d| d.name.clone().unwrap_or_default()).collect::<Vec<_>>())
})
}

fn find_all_file_descriptors_for_package<'a>(
Expand All @@ -118,27 +119,45 @@ fn find_all_file_descriptors_for_package<'a>(
} else {
package
};
debug!("Looking for file descriptors for package '{}'", package);
let found = all_descriptors
.values()
.filter(|descriptor| {
debug!("Checking file descriptor '{:?}' with package '{:?}'", descriptor.name, descriptor.package);
if let Some(descriptor_package) = &descriptor.package {
descriptor_package == package
} else {
false
}
})
.cloned()
.collect::<Vec<_>>();
let found = filter_file_descriptors(all_descriptors, |descriptor| {
debug!("Checking file descriptor '{:?}' with package '{:?}'", descriptor.name, descriptor.package);
if let Some(descriptor_package) = &descriptor.package {
descriptor_package == package
} else {
false
}
});
if found.is_empty() {
Err(anyhow!("Did not find a file descriptor with package '{}'", package))
Err(anyhow!("Did not find a file descriptor for a package '{}'", package))
} else {
debug!("Found {} file descriptors for package '{}'", found.len(), package);
Ok(found)
}
}

fn find_all_file_descriptors_with_no_package<'a>(
all_descriptors: &'a HashMap<String, &FileDescriptorProto>
) -> anyhow::Result<Vec<&'a FileDescriptorProto>> {
let found = filter_file_descriptors(all_descriptors, |d| d.package.is_none());
if found.is_empty() {
Err(anyhow!("Did not find any file descriptors with no package specified"))
} else {
debug!("Found {} file descriptors with no package", found.len());
Ok(found)
}
}

fn filter_file_descriptors<'a, F>(all_descriptors: &'a HashMap<String, &FileDescriptorProto>, filter: F
) -> Vec<&'a FileDescriptorProto>
where
F: FnMut(&&&FileDescriptorProto) -> bool,
{
all_descriptors.values()
.filter(filter)
.cloned()
.collect::<Vec<_>>()
}

/// If the field is a map field. A field will be a map field if it is a repeated field, the field
/// type is a message and the nested type has the map flag set on the message options.
pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> bool {
Expand Down Expand Up @@ -565,7 +584,8 @@ pub(crate) fn prost_string<S: Into<String>>(s: S) -> Value {

#[cfg(test)]
pub(crate) mod tests {
use bytes::Bytes;
use anyhow::anyhow;
use bytes::Bytes;
use expectest::prelude::*;
use maplit::hashmap;
use prost::Message;
Expand Down Expand Up @@ -923,7 +943,7 @@ pub(crate) mod tests {
],
.. FileDescriptorProto::default()
};
let request = FileDescriptorProto {
let request: FileDescriptorProto = FileDescriptorProto {
name: Some("request.proto".to_string()),
package: Some("service".to_string()),
message_type: vec![
Expand Down Expand Up @@ -953,9 +973,8 @@ pub(crate) mod tests {
],
.. FileDescriptorProto::default()
};
let request_diff_package = FileDescriptorProto {
let request_no_package = FileDescriptorProto {
name: Some("request_diff_package.proto".to_string()),
package: Some("diff".to_string()),
message_type: vec![
DescriptorProto {
name: Some("Request".to_string()),
Expand All @@ -973,23 +992,37 @@ pub(crate) mod tests {
.. FileDescriptorProto::default()
};
let all_descriptors = &hashmap!{
"service.proto".to_string() => &service,
"request.proto".to_string() => &request,
"response.proto".to_string() => &response,
"request_diff_package.proto".to_string() => &request_diff_package
"request_no_package.proto".to_string() => &request_no_package
};
// use default package name from the service descriptor
let result = find_message_descriptor("Request", None, &service, all_descriptors);
expect!(result.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field"));

// explicitly provide package name
let result_explicit_pkg: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Request", Some("diff"), &service, all_descriptors);
expect!(result_explicit_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"bool_field"));
let result_explicit_pkg: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Request", Some("service"), all_descriptors);
expect!(result_explicit_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field"));

// no package provided means search descriptors without packages only
let result_no_pkg: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Request", None, all_descriptors);
expect!(result_no_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"bool_field"));

// message not found error
let result_err: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Response", Some("diff"), &service, all_descriptors);
let result_err: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Missing", Some("service"), all_descriptors);
expect!(result_err.as_ref()).to(be_err());
expect!(result_err.unwrap_err().to_string()).to(
be_equal_to("Did not find a message type 'Response' in any of the file descriptors for package 'diff'")
);
expect!(result_err.unwrap_err().to_string()
.starts_with("Did not find a message type 'Missing' in any of the file descriptors"))
.to(be_true());

// file descriptor not found for package
let result_err_no_pkg: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Request", Some("missing"), all_descriptors);
expect!(result_err_no_pkg.as_ref()).to(be_err());
// "Did not find a file descriptor for package 'missing'"
expect!(result_err_no_pkg.unwrap_err().to_string())
.to(be_equal_to("Did not find a file descriptor for a package 'missing'"));

// no descriptors found without a package
let result_err_no_pkg: Result<DescriptorProto, anyhow::Error> = find_message_descriptor("Request", None, &hashmap!{});
expect!(result_err_no_pkg.as_ref()).to(be_err());
expect!(result_err_no_pkg.unwrap_err().to_string())
.to(be_equal_to("Did not find any file descriptors with no package specified"));
}
}

0 comments on commit a8c9f46

Please sign in to comment.