diff --git a/src/protobuf.rs b/src/protobuf.rs index 5d7fe2b..408ddcb 100644 --- a/src/protobuf.rs +++ b/src/protobuf.rs @@ -1652,8 +1652,8 @@ pub(crate) mod tests { method: vec![ MethodDescriptorProto { name: Some("call".to_string()), - input_type: Some("StringValue".to_string()), - output_type: Some("test_message".to_string()), + input_type: Some("test_package.StringValue".to_string()), + output_type: Some("test_package.test_message".to_string()), options: None, client_streaming: None, server_streaming: None @@ -1747,8 +1747,8 @@ pub(crate) mod tests { method: vec![ MethodDescriptorProto { name: Some("call".to_string()), - input_type: Some("StringValue".to_string()), - output_type: Some("test_message".to_string()), + input_type: Some("test_package.StringValue".to_string()), + output_type: Some("test_package.test_message".to_string()), options: None, client_streaming: None, server_streaming: None @@ -2709,39 +2709,4 @@ pub(crate) mod tests { ] )); } - - #[test] - fn find_message_descriptor_test() { - let descriptors = "CpAEChdpbXBvcnRlZC9pbXBvcnRlZC5wcm90bxIIaW1wb3J0ZWQiOQoJUmVjdGFuZ2x\ - lEhQKBXdpZHRoGAEgASgFUgV3aWR0aBIWCgZsZW5ndGgYAiABKAVSBmxlbmd0aCJIChhSZWN0YW5nbGVMb2NhdGlvblJ\ - lcXVlc3QSFAoFd2lkdGgYASABKAVSBXdpZHRoEhYKBmxlbmd0aBgCIAEoBVIGbGVuZ3RoIkgKGVJlY3RhbmdsZUxvY2F0\ - aW9uUmVzcG9uc2USKwoIbG9jYXRpb24YASABKAsyDy5pbXBvcnRlZC5Qb2ludFIIbG9jYXRpb24iQQoFUG9pbnQSGgoIb\ - GF0aXR1ZGUYASABKAVSCGxhdGl0dWRlEhwKCWxvbmdpdHVkZRgCIAEoBVIJbG9uZ2l0dWRlMmUKCEltcG9ydGVkElkKDE\ - dldFJlY3RhbmdsZRIiLmltcG9ydGVkLlJlY3RhbmdsZUxvY2F0aW9uUmVxdWVzdBojLmltcG9ydGVkLlJlY3RhbmdsZUxv\ - Y2F0aW9uUmVzcG9uc2UiAEJqChlpby5ncnBjLmV4YW1wbGVzLmltcG9ydGVkQg1JbXBvcnRlZFByb3RvUAFaPGdpdGh1Y\ - i5jb20vcGFjdC1mb3VuZGF0aW9uL3BhY3QtZ28vdjIvZXhhbXBsZXMvZ3JwYy9pbXBvcnRlZGIGcHJvdG8zCooECg1wcm\ - ltYXJ5LnByb3RvEgdwcmltYXJ5GhdpbXBvcnRlZC9pbXBvcnRlZC5wcm90byJNCglSZWN0YW5nbGUSHwoCbG8YASABKAs\ - yDy5pbXBvcnRlZC5Qb2ludFICbG8SHwoCaGkYAiABKAsyDy5pbXBvcnRlZC5Qb2ludFICaGkiZAoYUmVjdGFuZ2xlTG9j\ - YXRpb25SZXF1ZXN0EgwKAXgYASABKAVSAXgSDAoBeRgCIAEoBVIBeRIUCgV3aWR0aBgDIAEoBVIFd2lkdGgSFgoGbGVuZ\ - 3RoGAQgASgFUgZsZW5ndGgiTQoZUmVjdGFuZ2xlTG9jYXRpb25SZXNwb25zZRIwCglyZWN0YW5nbGUYASABKAsyEi5wcml\ - tYXJ5LlJlY3RhbmdsZVIJcmVjdGFuZ2xlMmIKB1ByaW1hcnkSVwoMR2V0UmVjdGFuZ2xlEiEucHJpbWFyeS5SZWN0YW5nb\ - GVMb2NhdGlvblJlcXVlc3QaIi5wcmltYXJ5LlJlY3RhbmdsZUxvY2F0aW9uUmVzcG9uc2UiAEJnChhpby5ncnBjLmV4YW1\ - wbGVzLnByaW1hcnlCDFByaW1hcnlQcm90b1ABWjtnaXRodWIuY29tL3BhY3QtZm91bmRhdGlvbi9wYWN0LWdvL3YyL2V4Y\ - W1wbGVzL2dycGMvcHJpbWFyeWIGcHJvdG8z"; - let decoded = BASE64.decode(descriptors).unwrap(); - let bytes = Bytes::copy_from_slice(decoded.as_slice()); - let fds = FileDescriptorSet::decode(bytes).unwrap(); - let all: HashMap = fds.file - .iter().map(|des| (des.name.clone().unwrap_or_default(), des)) - .collect(); - - let result = super::find_message_descriptor("RectangleLocationRequest", None, &all).unwrap(); - expect!(result.field.len()).to(be_equal_to(2)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some("primary"), &all).unwrap(); - expect!(result.field.len()).to(be_equal_to(4)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some(".primary"), &all).unwrap(); - expect!(result.field.len()).to(be_equal_to(4)); - let result = super::find_message_descriptor("RectangleLocationRequest", Some("imported"), &all).unwrap(); - expect!(result.field.len()).to(be_equal_to(2)); - } } diff --git a/src/utils.rs b/src/utils.rs index e94135f..4d91968 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -110,10 +110,10 @@ pub fn find_message_descriptor( }) } -fn find_all_file_descriptors_for_package<'a>( +fn find_all_file_descriptors_for_package( package: &str, - all_descriptors: &'a HashMap, -) -> anyhow::Result> { + all_descriptors: &HashMap, +) -> anyhow::Result> { let package = if package.starts_with('.') { &package[1..] } else { @@ -135,9 +135,9 @@ fn find_all_file_descriptors_for_package<'a>( } } -fn find_all_file_descriptors_with_no_package<'a>( - all_descriptors: &'a HashMap -) -> anyhow::Result> { +fn find_all_file_descriptors_with_no_package( + all_descriptors: &HashMap +) -> anyhow::Result> { let found = filter_file_descriptors(all_descriptors, |d| d.package.is_none()); if found.is_empty() { Err(anyhow!("Did not find any file descriptors with no package specified")) @@ -147,15 +147,18 @@ fn find_all_file_descriptors_with_no_package<'a>( } } -fn filter_file_descriptors<'a, F>(all_descriptors: &'a HashMap, filter: F -) -> Vec<&'a FileDescriptorProto> +fn filter_file_descriptors( + all_descriptors: &HashMap, + filter: F +) -> Vec where - F: FnMut(&&&FileDescriptorProto) -> bool, + F: FnMut(&&FileDescriptorProto) -> bool { - all_descriptors.values() - .filter(filter) - .cloned() - .collect::>() + all_descriptors.values() + .map(|fd| *fd) // Convert &&FileDescriptorProto -> &FileDescriptorProto + .filter(filter) + .cloned() + .collect() } /// If the field is a map field. A field will be a map field if it is a repeated field, the field @@ -584,8 +587,7 @@ pub(crate) fn prost_string>(s: S) -> Value { #[cfg(test)] pub(crate) mod tests { - use anyhow::anyhow; -use bytes::Bytes; + use bytes::Bytes; use expectest::prelude::*; use maplit::hashmap; use prost::Message; @@ -974,7 +976,7 @@ use bytes::Bytes; .. FileDescriptorProto::default() }; let request_no_package = FileDescriptorProto { - name: Some("request_diff_package.proto".to_string()), + name: Some("request_no_package.proto".to_string()), message_type: vec![ DescriptorProto { name: Some("Request".to_string()), @@ -998,29 +1000,32 @@ use bytes::Bytes; "request_no_package.proto".to_string() => &request_no_package }; // explicitly provide package name - let result_explicit_pkg: Result = find_message_descriptor("Request", Some("service"), all_descriptors); + let result_explicit_pkg = find_message_descriptor("Request", Some("service"), all_descriptors); expect!(result_explicit_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field")); + // same but with a dot + let result_explicit_pkg_dot = find_message_descriptor("Request", Some(".service"), all_descriptors); + expect!(result_explicit_pkg_dot.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"field")); // no package provided means search descriptors without packages only - let result_no_pkg: Result = find_message_descriptor("Request", None, all_descriptors); + let result_no_pkg = find_message_descriptor("Request", None, all_descriptors); expect!(result_no_pkg.as_ref().unwrap().field[0].name.as_ref()).to(be_some().value(&"bool_field")); // message not found error - let result_err: Result = find_message_descriptor("Missing", Some("service"), all_descriptors); + let result_err = find_message_descriptor("Missing", Some("service"), all_descriptors); expect!(result_err.as_ref()).to(be_err()); expect!(result_err.unwrap_err().to_string() .starts_with("Did not find a message type 'Missing' in any of the file descriptors")) .to(be_true()); // file descriptor not found for package - let result_err_no_pkg: Result = find_message_descriptor("Request", Some("missing"), all_descriptors); + let result_err_no_pkg = find_message_descriptor("Request", Some("missing"), all_descriptors); expect!(result_err_no_pkg.as_ref()).to(be_err()); // "Did not find a file descriptor for package 'missing'" expect!(result_err_no_pkg.unwrap_err().to_string()) .to(be_equal_to("Did not find a file descriptor for a package 'missing'")); // no descriptors found without a package - let result_err_no_pkg: Result = find_message_descriptor("Request", None, &hashmap!{}); + let result_err_no_pkg = find_message_descriptor("Request", None, &hashmap!{}); expect!(result_err_no_pkg.as_ref()).to(be_err()); expect!(result_err_no_pkg.unwrap_err().to_string()) .to(be_equal_to("Did not find any file descriptors with no package specified"));