Fix gRPC Any response reflection (#451)

This commit is contained in:
Gregory Schier
2026-05-06 07:42:35 -07:00
committed by GitHub
parent 4c15a49f8f
commit a200410697
6 changed files with 171 additions and 10 deletions

View File

@@ -37,7 +37,7 @@ pub struct MethodDefinition {
static SERIALIZE_OPTIONS: &'static SerializeOptions =
&SerializeOptions::new().skip_default_fields(false).stringify_64_bit_integers(false);
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
pub(crate) fn serialize_dynamic_message_json(msg: &DynamicMessage) -> Result<String, String> {
let mut buf = Vec::new();
let mut se = serde_json::Serializer::pretty(&mut buf);
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS).map_err(|e| e.to_string())?;

View File

@@ -2,7 +2,8 @@ use crate::codec::DynamicCodec;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::reflection::{
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path, reflect_types_for_message,
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path,
reflect_types_for_dynamic_message, reflect_types_for_message,
};
use crate::transport::get_transport;
use crate::{MethodDefinition, ServiceDefinition, json_schema};
@@ -11,8 +12,11 @@ use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::{info, warn};
pub use prost_reflect::DynamicMessage;
use prost_reflect::ReflectMessage;
use prost_reflect::prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
@@ -115,6 +119,38 @@ impl GrpcConnection {
Ok(client.unary(req, path, codec).await?)
}
pub async fn serialize_message(
&self,
message: &DynamicMessage,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<String> {
let message = if self.use_reflection {
reflect_types_for_dynamic_message(
self.pool.clone(),
&self.uri,
message,
metadata,
client_cert,
)
.await?;
let message_name = message.descriptor().full_name().to_string();
let message_desc = {
let pool = self.pool.read().await;
pool.get_message_by_name(&message_name)
.ok_or(GenericError(format!("Failed to find message {message_name}")))?
};
let mut message_with_updated_pool = DynamicMessage::new(message_desc);
message_with_updated_pool.merge(message.encode_to_vec().as_slice())?;
Cow::Owned(message_with_updated_pool)
} else {
Cow::Borrowed(message)
};
crate::serialize_dynamic_message_json(message.as_ref()).map_err(GenericError)
}
pub async fn streaming<F>(
&self,
service: &str,

View File

@@ -7,7 +7,7 @@ use anyhow::anyhow;
use async_recursion::async_recursion;
use log::{debug, info, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_reflect::{DescriptorPool, DynamicMessage, MethodDescriptor, ReflectMessage, Value};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use std::collections::{BTreeMap, HashSet};
use std::env::temp_dir;
@@ -233,6 +233,83 @@ pub(crate) async fn reflect_types_for_message(
Ok(())
}
pub(crate) async fn reflect_types_for_dynamic_message(
pool: Arc<RwLock<DescriptorPool>>,
uri: &Uri,
message: &DynamicMessage,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<()> {
let mut extra_types = HashSet::new();
collect_any_types_from_dynamic_message(message, &mut extra_types);
if extra_types.is_empty() {
return Ok(());
}
let mut client = AutoReflectionClient::new(uri, false, client_cert)?;
for extra_type in extra_types {
{
let guard = pool.read().await;
if guard.get_message_by_name(&extra_type).is_some() {
continue;
}
}
info!("Adding response file descriptor for {:?} from reflection", extra_type);
let req = MessageRequest::FileContainingSymbol(extra_type.clone().into());
let resp = match client.send_reflection_request(req, metadata).await {
Ok(r) => r,
Err(e) => {
return Err(GenericError(format!(
"Error sending reflection request for response @type \"{extra_type}\": {e:?}",
)));
}
};
let files = match resp {
MessageResponse::FileDescriptorResponse(resp) => resp.file_descriptor_proto,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
{
let mut guard = pool.write().await;
add_file_descriptors_to_pool(files, &mut *guard, &mut client, metadata).await;
}
}
Ok(())
}
fn collect_any_types_from_dynamic_message(message: &DynamicMessage, out: &mut HashSet<String>) {
if message.descriptor().full_name() == "google.protobuf.Any" {
if let Some(Value::String(type_url)) = message.get_field_by_name("type_url").as_deref() {
if let Some(full_name) = type_url.rsplit_once('/').map(|(_, name)| name) {
out.insert(full_name.to_string());
}
}
}
for (_, value) in message.fields() {
collect_any_types_from_value(value, out);
}
}
fn collect_any_types_from_value(value: &Value, out: &mut HashSet<String>) {
match value {
Value::Message(message) => collect_any_types_from_dynamic_message(message, out),
Value::List(values) => {
for value in values {
collect_any_types_from_value(value, out);
}
}
Value::Map(values) => {
for value in values.values() {
collect_any_types_from_value(value, out);
}
}
_ => {}
}
}
#[async_recursion]
pub(crate) async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,