From a8c9f467701ccb1968f454fcfcc55413d995d2ec Mon Sep 17 00:00:00 2001 From: Stan Date: Tue, 4 Jun 2024 23:51:27 -0700 Subject: [PATCH] Search descritors with no package when no package was specified. --- src/mock_server.rs | 6 +-- src/protobuf.rs | 15 +++--- src/utils.rs | 131 ++++++++++++++++++++++++++++----------------- 3 files changed, 92 insertions(+), 60 deletions(-) diff --git a/src/mock_server.rs b/src/mock_server.rs index 50a9433..e4eb456 100644 --- a/src/mock_server.rs +++ b/src/mock_server.rs @@ -241,7 +241,7 @@ impl Service> for GrpcMockServer { if let Some((service, method)) = request_path[1..].split_once('/') { let service_name = last_name(service); let lookup = format!("{service_name}/{method}"); - if let Some((file, file_descriptor, method_descriptor, message)) = routes.get(lookup.as_str()) { + if let Some((file, _file_descriptor, method_descriptor, message)) = routes.get(lookup.as_str()) { trace!(message = message.description.as_str(), "Found route for service call"); let file_descriptors: HashMap = file.file.iter().map( |des| (des.name.clone().unwrap_or_default(), des)).collect(); @@ -249,13 +249,13 @@ impl Service> for GrpcMockServer { "Input message name is empty for service {}/{}", service_name, method).as_str()); let (input_message_name, input_package_name) = split_name(input_name); let input_message = find_message_descriptor( - input_message_name, input_package_name, file_descriptor, &file_descriptors); + input_message_name, input_package_name, &file_descriptors); let output_name = method_descriptor.output_type.as_ref().expect(format!( "Output message name is empty for service {}/{}", service_name, method).as_str()); let (output_message_name, output_package_name) = split_name(output_name); let output_message = find_message_descriptor( - output_message_name, output_package_name, file_descriptor, &file_descriptors); + output_message_name, output_package_name, &file_descriptors); if let Ok(input_message) = input_message { if let Ok(output_message) = output_message { diff --git a/src/protobuf.rs b/src/protobuf.rs index d53a8b7..5d7fe2b 100644 --- a/src/protobuf.rs +++ b/src/protobuf.rs @@ -189,8 +189,8 @@ fn construct_protobuf_interaction_for_service( trace!(%input_name, ?input_package, input_message_name, "Input message"); trace!(%output_name, ?output_package, output_message_name, "Output message"); - let request_descriptor = find_message_descriptor(input_message_name, input_package, file_descriptor, all_descriptors)?; - let response_descriptor = find_message_descriptor(output_message_name, output_package, file_descriptor, all_descriptors)?; + let request_descriptor = find_message_descriptor(input_message_name, input_package, all_descriptors)?; + let response_descriptor = find_message_descriptor(output_message_name, output_package, all_descriptors)?; trace!("request_descriptor = {:?}", request_descriptor); trace!("response_descriptor = {:?}", response_descriptor); @@ -636,7 +636,7 @@ fn build_single_embedded_field_value( debug!("Configuring the message from config {:?}", config); let (message_name, package_name) = split_name(type_name.as_str()); let embedded_type = find_nested_type(&message_builder.descriptor, field_descriptor) - .or_else(|| find_message_descriptor(message_name, package_name, &message_builder.file_descriptor, all_descriptors).ok()) + .or_else(|| find_message_descriptor(message_name, package_name, all_descriptors).ok()) .ok_or_else(|| anyhow!("Did not find message '{}' in the current message or in the file descriptors", type_name))?; let mut embedded_builder = MessageBuilder::new(&embedded_type, message_name, &message_builder.file_descriptor); @@ -2734,15 +2734,14 @@ pub(crate) mod tests { let all: HashMap = fds.file .iter().map(|des| (des.name.clone().unwrap_or_default(), des)) .collect(); - let file_descriptor = &fds.file[0]; - let result = super::find_message_descriptor("RectangleLocationRequest", None, file_descriptor, &all).unwrap(); + let result = super::find_message_descriptor("RectangleLocationRequest", None, &all).unwrap(); expect!(result.field.len()).to(be_equal_to(2)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), file_descriptor, &all).unwrap(); + let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), &all).unwrap(); expect!(result.field.len()).to(be_equal_to(4)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), file_descriptor, &all).unwrap(); + let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), &all).unwrap(); expect!(result.field.len()).to(be_equal_to(4)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), file_descriptor, &all).unwrap(); + let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), &all).unwrap(); expect!(result.field.len()).to(be_equal_to(2)); } } diff --git a/src/utils.rs b/src/utils.rs index 74c8ca4..e94135f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -82,31 +82,32 @@ pub fn find_message_type_in_file_descriptor( message_name, descriptor.name.as_deref().unwrap_or("unknown"))) } -/// Very similar to find_message_type_by_name, but it narrows down a list of file descriptors to the ones where package -/// name matches the given package name, either the one provided explicitly, or the one from the file_descriptor. +/// Finds message descriptor in the map of all file descriptors. If the package is provided, it will +/// search only the descriptors matching the package. If not, it will search all descriptors with no package specified. +/// (because package is an optional field in proto3) pub fn find_message_descriptor( message_name: &str, package: Option<&str>, - file_descriptor: &FileDescriptorProto, all_descriptors: &HashMap, ) -> anyhow::Result { - let package = package - .filter(|p| !p.is_empty()) - .unwrap_or(file_descriptor.package.as_deref().unwrap_or_default()); - debug!( - "Looking for message '{}' in package '{}'", - message_name, package - ); - find_all_file_descriptors_for_package(package, all_descriptors)? - .iter() - .find_map(|fd| find_message_type_in_file_descriptor(message_name, fd).ok()) - .ok_or_else(|| { - anyhow!( - "Did not find a message type '{}' in any of the file descriptors for package '{}'", - message_name, - package, - ) - }) + let descriptors; + if let Some(package) = package { + debug!( + "Looking for message '{}' in package '{}'", + message_name, package + ); + descriptors = find_all_file_descriptors_for_package(package, all_descriptors)?; + } else { + descriptors = find_all_file_descriptors_with_no_package(all_descriptors)?; + } + descriptors.iter() + .find_map(|fd| find_message_type_in_file_descriptor(message_name, fd).ok()) + .ok_or_else(|| { + anyhow!( + "Did not find a message type '{}' in any of the file descriptors '{:?}'", + message_name, + descriptors.iter().map(|d| d.name.clone().unwrap_or_default()).collect::>()) + }) } fn find_all_file_descriptors_for_package<'a>( @@ -118,27 +119,45 @@ fn find_all_file_descriptors_for_package<'a>( } else { package }; - debug!("Looking for file descriptors for package '{}'", package); - let found = all_descriptors - .values() - .filter(|descriptor| { - debug!("Checking file descriptor '{:?}' with package '{:?}'", descriptor.name, descriptor.package); - if let Some(descriptor_package) = &descriptor.package { - descriptor_package == package - } else { - false - } - }) - .cloned() - .collect::>(); + let found = filter_file_descriptors(all_descriptors, |descriptor| { + debug!("Checking file descriptor '{:?}' with package '{:?}'", descriptor.name, descriptor.package); + if let Some(descriptor_package) = &descriptor.package { + descriptor_package == package + } else { + false + } + }); if found.is_empty() { - Err(anyhow!("Did not find a file descriptor with package '{}'", package)) + Err(anyhow!("Did not find a file descriptor for a package '{}'", package)) } else { debug!("Found {} file descriptors for package '{}'", found.len(), package); Ok(found) } } +fn find_all_file_descriptors_with_no_package<'a>( + all_descriptors: &'a HashMap +) -> anyhow::Result> { + let found = filter_file_descriptors(all_descriptors, |d| d.package.is_none()); + if found.is_empty() { + Err(anyhow!("Did not find any file descriptors with no package specified")) + } else { + debug!("Found {} file descriptors with no package", found.len()); + Ok(found) + } +} + +fn filter_file_descriptors<'a, F>(all_descriptors: &'a HashMap, filter: F +) -> Vec<&'a FileDescriptorProto> +where + F: FnMut(&&&FileDescriptorProto) -> bool, +{ + all_descriptors.values() + .filter(filter) + .cloned() + .collect::>() +} + /// If the field is a map field. A field will be a map field if it is a repeated field, the field /// type is a message and the nested type has the map flag set on the message options. pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> bool { @@ -565,7 +584,8 @@ pub(crate) fn prost_string>(s: S) -> Value { #[cfg(test)] pub(crate) mod tests { - use bytes::Bytes; + use anyhow::anyhow; +use bytes::Bytes; use expectest::prelude::*; use maplit::hashmap; use prost::Message; @@ -923,7 +943,7 @@ pub(crate) mod tests { ], .. FileDescriptorProto::default() }; - let request = FileDescriptorProto { + let request: FileDescriptorProto = FileDescriptorProto { name: Some("request.proto".to_string()), package: Some("service".to_string()), message_type: vec![ @@ -953,9 +973,8 @@ pub(crate) mod tests { ], .. FileDescriptorProto::default() }; - let request_diff_package = FileDescriptorProto { + let request_no_package = FileDescriptorProto { name: Some("request_diff_package.proto".to_string()), - package: Some("diff".to_string()), message_type: vec![ DescriptorProto { name: Some("Request".to_string()), @@ -973,23 +992,37 @@ pub(crate) mod tests { .. FileDescriptorProto::default() }; let all_descriptors = &hashmap!{ + "service.proto".to_string() => &service, "request.proto".to_string() => &request, "response.proto".to_string() => &response, - "request_diff_package.proto".to_string() => &request_diff_package + "request_no_package.proto".to_string() => &request_no_package }; - // use default package name from the service descriptor - let result = find_message_descriptor("Request", None, &service, all_descriptors); - expect!(result.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field")); - // explicitly provide package name - let result_explicit_pkg: Result = find_message_descriptor("Request", Some("diff"), &service, all_descriptors); - expect!(result_explicit_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"bool_field")); + let result_explicit_pkg: Result = find_message_descriptor("Request", Some("service"), all_descriptors); + expect!(result_explicit_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field")); + + // no package provided means search descriptors without packages only + let result_no_pkg: Result = find_message_descriptor("Request", None, all_descriptors); + expect!(result_no_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"bool_field")); // message not found error - let result_err: Result = find_message_descriptor("Response", Some("diff"), &service, all_descriptors); + let result_err: Result = find_message_descriptor("Missing", Some("service"), all_descriptors); expect!(result_err.as_ref()).to(be_err()); - expect!(result_err.unwrap_err().to_string()).to( - be_equal_to("Did not find a message type 'Response' in any of the file descriptors for package 'diff'") - ); + expect!(result_err.unwrap_err().to_string() + .starts_with("Did not find a message type 'Missing' in any of the file descriptors")) + .to(be_true()); + + // file descriptor not found for package + let result_err_no_pkg: Result = find_message_descriptor("Request", Some("missing"), all_descriptors); + expect!(result_err_no_pkg.as_ref()).to(be_err()); + // "Did not find a file descriptor for package 'missing'" + expect!(result_err_no_pkg.unwrap_err().to_string()) + .to(be_equal_to("Did not find a file descriptor for a package 'missing'")); + + // no descriptors found without a package + let result_err_no_pkg: Result = find_message_descriptor("Request", None, &hashmap!{}); + expect!(result_err_no_pkg.as_ref()).to(be_err()); + expect!(result_err_no_pkg.unwrap_err().to_string()) + .to(be_equal_to("Did not find any file descriptors with no package specified")); } }