Skip to content

Commit

Permalink
Merge pull request pactflow#60 from ermul/package-namespace-fixes
Browse files Browse the repository at this point in the history
fix: respect package namespaces when resolving message descriptors
  • Loading branch information
rholshausen authored May 9, 2024
2 parents 0ad6afd + 0f47677 commit aa28fb4
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 95 deletions.
14 changes: 9 additions & 5 deletions integrated_tests/imported_message/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure().include_file("mod.rs").compile(
&["primary/primary.proto", "imported/imported.proto"],
&["."],
)?;
Ok(())
tonic_build::configure().include_file("mod.rs").compile(
&[
"primary/primary.proto",
"imported/imported.proto",
"zimported/zimported.proto",
],
&["."],
)?;
Ok(())
}
4 changes: 2 additions & 2 deletions integrated_tests/imported_message/imported/imported.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ message Rectangle {
// Request message for GetRectangle method. This message has different fields,
// but the same name as a message defined in primary.proto
message RectangleLocationRequest {
int32 width = 1;
int32 length = 2;
int32 a = 1;
int32 b = 2;
}

// Response message for GetRectangle method. This message has different fields,
Expand Down
1 change: 1 addition & 0 deletions integrated_tests/imported_message/primary/primary.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ option java_package = "io.grpc.examples.primary";
option java_outer_classname = "PrimaryProto";

import "imported/imported.proto";
import "zimported/zimported.proto";

package primary;

Expand Down
96 changes: 48 additions & 48 deletions integrated_tests/imported_message/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,41 @@ tonic::include_proto!("mod");

#[cfg(test)]
mod tests {
use std::path::Path;
use std::path::Path;

use crate::primary::primary_client::PrimaryClient;
use pact_consumer::mock_server::StartMockServerAsync;
use pact_consumer::prelude::*;
use serde_json::json;
use tonic::IntoRequest;
use tracing::info;
use crate::primary::primary_client::PrimaryClient;
use pact_consumer::mock_server::StartMockServerAsync;
use pact_consumer::prelude::*;
use serde_json::json;
use tonic::IntoRequest;
use tracing::info;

use super::*;
use super::*;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_proto_client() {
let _ = env_logger::builder().is_test(true).try_init();
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_proto_client() {
let _ = env_logger::builder().is_test(true).try_init();

let mut pact_builder = PactBuilderAsync::new_v4("grpc-consumer-rust", "imported_message");
let mock_server = pact_builder
.using_plugin("protobuf", None)
.await
.synchronous_message_interaction(
"package namespace not respected",
|mut i| async move {
let proto_file = Path::new("primary/primary.proto")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
let proto_include = Path::new(".")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
info!("proto_file: {}", proto_file);
info!("proto_include: {}", proto_include);
i.contents_from(json!({
let mut pact_builder = PactBuilderAsync::new_v4("grpc-consumer-rust", "imported_message");
let mock_server = pact_builder
.using_plugin("protobuf", None)
.await
.synchronous_message_interaction(
"package namespace not respected",
|mut i| async move {
let proto_file = Path::new("primary/primary.proto")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
let proto_include = Path::new(".")
.canonicalize()
.unwrap()
.to_string_lossy()
.to_string();
info!("proto_file: {}", proto_file);
info!("proto_include: {}", proto_include);
i.contents_from(json!({
"pact:proto": proto_file,
"pact:proto-service": "Primary/GetRectangle",
"pact:content-type": "application/protobuf",
Expand All @@ -62,25 +62,25 @@ mod tests {
}
}
}))
.await;
i
},
)
.await
.start_mock_server_async(Some("protobuf/transport/grpc"))
.await;
i
},
)
.await
.start_mock_server_async(Some("protobuf/transport/grpc"))
.await;

let url = mock_server.url();
let url = mock_server.url();

let mut client = PrimaryClient::connect(url.to_string()).await.unwrap();
let request_message = primary::RectangleLocationRequest {
x: 180,
y: 200,
width: 10,
length: 20,
};
let mut client = PrimaryClient::connect(url.to_string()).await.unwrap();
let request_message = primary::RectangleLocationRequest {
x: 180,
y: 200,
width: 10,
length: 20,
};

let response = client.get_rectangle(request_message.into_request()).await;
let _response_message = response.unwrap();
}
let response = client.get_rectangle(request_message.into_request()).await;
let _response_message = response.unwrap();
}
}
39 changes: 39 additions & 0 deletions integrated_tests/imported_message/zimported/zimported.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

syntax = "proto3";

option go_package = "github.com/pact-foundation/pact-go/v2/examples/grpc/zimported";
option java_multiple_files = true;
option java_package = "io.grpc.examples.zimported";
option java_outer_classname = "ImportedProto";

package zimported;

service ZImported {
rpc GetRectangle(RectangleLocationRequest) returns (RectangleLocationResponse) {}
}

message Rectangle {
// The width of the rectangle.
int32 zwidth = 1;

// The length of the rectangle.
int32 zlength = 2;
}

// Request message for GetRectangle method. This message has different fields,
// but the same name as a message defined in primary.proto
message RectangleLocationRequest {
int32 zx = 1;
int32 zb = 2;
}

// Response message for GetRectangle method. This message has different fields,
// but the same name as a message defined in primary.proto
message RectangleLocationResponse {
Point zlocation = 1;
}

message Point {
int32 zlatitude = 1;
int32 zlongitude = 2;
}
45 changes: 25 additions & 20 deletions src/mock_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use pact_models::plugins::PluginData;
use pact_models::prelude::v4::V4Pact;
use pact_models::v4::sync_message::SynchronousMessage;
use prost::Message;
use prost_types::{FileDescriptorSet, MethodDescriptorProto};
use prost_types::{FileDescriptorProto, FileDescriptorSet, MethodDescriptorProto};
use serde_json::{json, Value};
use tokio::net::TcpListener;
use tokio::runtime::Handle;
Expand All @@ -42,7 +42,7 @@ use crate::dynamic_message::PactCodec;
use crate::metadata::MetadataMatchResult;
use crate::mock_service::MockService;
use crate::tcp::TcpIncoming;
use crate::utils::{find_message_type_by_name, last_name};
use crate::utils::{last_name, split_name, find_message_descriptor};

lazy_static! {
pub static ref MOCK_SERVER_STATE: Mutex<HashMap<String, (Sender<()>, HashMap<String, (usize, Vec<(BodyMatchResult, MetadataMatchResult)>)>)>> = Mutex::new(hashmap!{});
Expand All @@ -54,7 +54,7 @@ pub struct GrpcMockServer {
pact: V4Pact,
plugin_config: PluginData,
descriptors: HashMap<String, FileDescriptorSet>,
routes: HashMap<String, (FileDescriptorSet, MethodDescriptorProto, SynchronousMessage)>,
routes: HashMap<String, (FileDescriptorSet, FileDescriptorProto, MethodDescriptorProto, SynchronousMessage)>,
/// Server key for this mock server
pub server_key: String,
/// test context pass in from the test framework
Expand Down Expand Up @@ -103,14 +103,11 @@ impl GrpcMockServer
if let Some(descriptors) = self.descriptors.get(json_to_string(key).as_str()) {
if let Some(service) = c.get("service") {
if let Some((service_name, method_name)) = json_to_string(service).split_once('/') {
descriptors.file.iter().filter_map(|d| {
d.service.iter().find(|s| s.name.clone().unwrap_or_default() == service_name)
}).next()
.and_then(|d| {
d.method.iter()
.find(|m| m.name.clone().unwrap_or_default() == method_name)
.map(|m| (format!("{service_name}/{method_name}"), (descriptors.clone(), m.clone(), i.clone())))
})
return descriptors.file.iter().find_map(|fd| fd.service.iter().find(|s| s.name.clone().unwrap_or_default() == service_name).map( |s| s.method.iter().
find(|m| m.name.clone().unwrap_or_default() == method_name).
map(|m| (format!("{service_name}/{method_name}"), (descriptors.clone(), fd.clone(), m.clone(), i.clone())))
)
).unwrap();
} else {
// protobuf service was not properly formed <SERViCE>/<METHOD>
None
Expand Down Expand Up @@ -244,16 +241,24 @@ impl Service<Request<hyper::Body>> for GrpcMockServer {
if let Some((service, method)) = request_path[1..].split_once('/') {
let service_name = last_name(service);
let lookup = format!("{service_name}/{method}");
if let Some((file, method_descriptor, message)) = routes.get(lookup.as_str()) {
if let Some((file, file_descriptor, method_descriptor, message)) = routes.get(lookup.as_str()) {
trace!(message = message.description.as_str(), "Found route for service call");

let input_message_name = method_descriptor.input_type.clone().unwrap_or_default();
let input_message = find_message_type_by_name(last_name(input_message_name.as_str()), file);
let output_message_name = method_descriptor.output_type.clone().unwrap_or_default();
let output_message = find_message_type_by_name(last_name(output_message_name.as_str()), file);

if let Ok((input_message, _)) = input_message {
if let Ok((output_message, _)) = output_message {
let file_descriptors: HashMap<String, &FileDescriptorProto> = file.file.iter().map(
|des| (des.name.clone().unwrap_or_default(), des)).collect();
let input_name = method_descriptor.input_type.as_ref().expect(format!(
"Input message name is empty for service {}/{}", service_name, method).as_str());
let (input_message_name, input_package_name) = split_name(input_name);
let input_message = find_message_descriptor(
input_message_name, input_package_name, file_descriptor, &file_descriptors);

let output_name = method_descriptor.output_type.as_ref().expect(format!(
"Output message name is empty for service {}/{}", service_name, method).as_str());
let (output_message_name, output_package_name) = split_name(output_name);
let output_message = find_message_descriptor(
output_message_name, output_package_name, file_descriptor, &file_descriptors);

if let Ok(input_message) = input_message {
if let Ok(output_message) = output_message {
let codec = PactCodec::new(file, &input_message, &output_message, message);
let mock_service = MockService::new(file, service_name,
method_descriptor, &input_message, &output_message, message, server_key.as_str(),
Expand Down
22 changes: 3 additions & 19 deletions src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ use crate::protoc::Protoc;
use crate::utils::{
find_enum_value_by_name,
find_enum_value_by_name_in_message,
find_message_type_in_file_descriptors,
find_message_with_package_in_file_descriptors,
find_message_descriptor,
find_nested_type,
is_map_field,
is_repeated_field,
last_name,
prost_string,
split_name
};
Expand Down Expand Up @@ -284,20 +282,6 @@ fn request_part(
}
}

// Search for a message by name, first in the current file descriptor, then in all descriptors.
fn find_message_descriptor(
message_name: &str,
package: Option<&str>,
file_descriptor: &FileDescriptorProto,
all_descriptors: &HashMap<String, &FileDescriptorProto>
) -> anyhow::Result<DescriptorProto> {
if let Some(package) = package {
find_message_with_package_in_file_descriptors(message_name, package, file_descriptor, all_descriptors)
} else {
find_message_type_in_file_descriptors(message_name, file_descriptor, all_descriptors)
}
}

/// Configure the interaction for a single Protobuf message
fn configure_protobuf_message(
message_name: &str,
Expand Down Expand Up @@ -650,9 +634,9 @@ fn build_single_embedded_field_value(
Ok(None)
} else if let Value::Object(config) = value {
debug!("Configuring the message from config {:?}", config);
let message_name = last_name(type_name.as_str());
let (message_name, package_name) = split_name(type_name.as_str());
let embedded_type = find_nested_type(&message_builder.descriptor, field_descriptor)
.or_else(|| find_message_type_in_file_descriptors(message_name, &message_builder.file_descriptor, all_descriptors).ok())
.or_else(|| find_message_descriptor(message_name, package_name, &message_builder.file_descriptor, all_descriptors).ok())
.ok_or_else(|| anyhow!("Did not find message '{}' in the current message or in the file descriptors", type_name))?;
let mut embedded_builder = MessageBuilder::new(&embedded_type, message_name, &message_builder.file_descriptor);

Expand Down
20 changes: 19 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ pub fn split_name(name: &str) -> (&str, Option<&str>) {
if package.is_empty() {
(name, None)
} else {
(name, Some(package))
if let Some(trimmed) = package.strip_prefix(".") {
(name, Some(trimmed))
} else {
(name, Some(package))
}
}
})
.unwrap_or_else(|| (name, None))
Expand Down Expand Up @@ -93,6 +97,20 @@ pub fn find_message_type_in_file_descriptors(
})
}

// Search for a message by name, first in the current file descriptor, then in all descriptors.
pub fn find_message_descriptor(
message_name: &str,
package: Option<&str>,
file_descriptor: &FileDescriptorProto,
all_descriptors: &HashMap<String, &FileDescriptorProto>
) -> anyhow::Result<DescriptorProto> {
if let Some(package) = package {
find_message_with_package_in_file_descriptors(message_name, package, file_descriptor, all_descriptors)
} else {
find_message_type_in_file_descriptors(message_name, file_descriptor, all_descriptors)
}
}

/// Search for a message by type name and package in the file descriptor, and if not found,
/// search in all the descriptors. This will first check with the package name, and if nothing is
/// found, will then fall back to just using the message name
Expand Down

0 comments on commit aa28fb4

Please sign in to comment.