Send grpc metadata/auth with reflection requests

Closes https://feedback.yaak.app/p/send-metadata-during-grpc-reflection
This commit is contained in:
Gregory Schier
2025-05-11 07:20:57 -07:00
parent 5f8d99ba64
commit 035fe54df0
5 changed files with 124 additions and 71 deletions

View File

@@ -1,10 +1,14 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::error::Result;
use KeyAndValueRef::{Ascii, Binary}; use KeyAndValueRef::{Ascii, Binary};
use tauri::{Manager, Runtime, WebviewWindow};
use yaak_grpc::{KeyAndValueRef, MetadataMap}; use yaak_grpc::{KeyAndValueRef, MetadataMap};
use yaak_models::models::GrpcRequest;
use yaak_plugins::events::{CallHttpAuthenticationRequest, HttpHeader};
use yaak_plugins::manager::PluginManager;
pub fn metadata_to_map(metadata: MetadataMap) -> BTreeMap<String, String> { pub(crate) fn metadata_to_map(metadata: MetadataMap) -> BTreeMap<String, String> {
let mut entries = BTreeMap::new(); let mut entries = BTreeMap::new();
for r in metadata.iter() { for r in metadata.iter() {
match r { match r {
@@ -14,3 +18,48 @@ pub fn metadata_to_map(metadata: MetadataMap) -> BTreeMap<String, String> {
} }
entries entries
} }
pub(crate) async fn build_metadata<R: Runtime>(
window: &WebviewWindow<R>,
request: &GrpcRequest,
) -> Result<BTreeMap<String, String>> {
let plugin_manager = window.state::<PluginManager>();
let mut metadata = BTreeMap::new();
// Add the rest of metadata
for h in request.clone().metadata {
if h.name.is_empty() && h.value.is_empty() {
continue;
}
if !h.enabled {
continue;
}
metadata.insert(h.name, h.value);
}
if let Some(auth_name) = request.authentication_type.clone() {
let auth = request.authentication.clone();
let plugin_req = CallHttpAuthenticationRequest {
context_id: format!("{:x}", md5::compute(request.id.clone())),
values: serde_json::from_value(serde_json::to_value(&auth).unwrap()).unwrap(),
method: "POST".to_string(),
url: request.url.clone(),
headers: metadata
.iter()
.map(|(name, value)| HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
.collect(),
};
let plugin_result =
plugin_manager.call_http_authentication(&window, &auth_name, plugin_req).await?;
for header in plugin_result.set_headers {
metadata.insert(header.name, header.value);
}
}
Ok(metadata)
}

View File

@@ -1,7 +1,7 @@
extern crate core; extern crate core;
use crate::encoding::read_response_body; use crate::encoding::read_response_body;
use crate::error::Error::GenericError; use crate::error::Error::GenericError;
use crate::grpc::metadata_to_map; use crate::grpc::{build_metadata, metadata_to_map};
use crate::http_request::send_http_request; use crate::http_request::send_http_request;
use crate::notifications::YaakNotifier; use crate::notifications::YaakNotifier;
use crate::render::{render_grpc_request, render_template}; use crate::render::{render_grpc_request, render_template};
@@ -38,9 +38,9 @@ use yaak_models::util::{
BatchUpsertResult, UpdateSource, get_workspace_export_resources, maybe_gen_id, maybe_gen_id_opt, BatchUpsertResult, UpdateSource, get_workspace_export_resources, maybe_gen_id, maybe_gen_id_opt,
}; };
use yaak_plugins::events::{ use yaak_plugins::events::{
BootResponse, CallHttpAuthenticationRequest, CallHttpRequestActionRequest, FilterResponse, BootResponse, CallHttpRequestActionRequest, FilterResponse,
GetHttpAuthenticationConfigResponse, GetHttpAuthenticationSummaryResponse, GetHttpAuthenticationConfigResponse, GetHttpAuthenticationSummaryResponse,
GetHttpRequestActionsResponse, GetTemplateFunctionsResponse, HttpHeader, InternalEvent, GetHttpRequestActionsResponse, GetTemplateFunctionsResponse, InternalEvent,
InternalEventPayload, JsonPrimitive, PluginWindowContext, RenderPurpose, InternalEventPayload, JsonPrimitive, PluginWindowContext, RenderPurpose,
}; };
use yaak_plugins::manager::PluginManager; use yaak_plugins::manager::PluginManager;
@@ -166,6 +166,7 @@ async fn cmd_grpc_reflect<R: Runtime>(
.await?; .await?;
let uri = safe_uri(&req.url); let uri = safe_uri(&req.url);
let metadata = build_metadata(&window, &req).await?;
Ok(grpc_handle Ok(grpc_handle
.lock() .lock()
@@ -174,6 +175,7 @@ async fn cmd_grpc_reflect<R: Runtime>(
&req.id, &req.id,
&uri, &uri,
&proto_files.iter().map(|p| PathBuf::from_str(p).unwrap()).collect(), &proto_files.iter().map(|p| PathBuf::from_str(p).unwrap()).collect(),
&metadata,
) )
.await .await
.map_err(|e| GenericError(e.to_string()))?) .map_err(|e| GenericError(e.to_string()))?)
@@ -186,7 +188,6 @@ async fn cmd_grpc_go<R: Runtime>(
proto_files: Vec<String>, proto_files: Vec<String>,
app_handle: AppHandle<R>, app_handle: AppHandle<R>,
window: WebviewWindow<R>, window: WebviewWindow<R>,
plugin_manager: State<'_, PluginManager>,
grpc_handle: State<'_, Mutex<GrpcHandle>>, grpc_handle: State<'_, Mutex<GrpcHandle>>,
) -> YaakResult<String> { ) -> YaakResult<String> {
let environment = match environment_id { let environment = match environment_id {
@@ -208,42 +209,7 @@ async fn cmd_grpc_go<R: Runtime>(
) )
.await?; .await?;
let mut metadata = BTreeMap::new(); let metadata = build_metadata(&window, &request).await?;
// Add the rest of metadata
for h in request.clone().metadata {
if h.name.is_empty() && h.value.is_empty() {
continue;
}
if !h.enabled {
continue;
}
metadata.insert(h.name, h.value);
}
if let Some(auth_name) = request.authentication_type.clone() {
let auth = request.authentication.clone();
let plugin_req = CallHttpAuthenticationRequest {
context_id: format!("{:x}", md5::compute(request_id.to_string())),
values: serde_json::from_value(serde_json::to_value(&auth).unwrap()).unwrap(),
method: "POST".to_string(),
url: request.url.clone(),
headers: metadata
.iter()
.map(|(name, value)| HttpHeader {
name: name.to_string(),
value: value.to_string(),
})
.collect(),
};
let plugin_result =
plugin_manager.call_http_authentication(&window, &auth_name, plugin_req).await?;
for header in plugin_result.set_headers {
metadata.insert(header.name, header.value);
}
}
let conn = app_handle.db().upsert_grpc_connection( let conn = app_handle.db().upsert_grpc_connection(
&GrpcConnection { &GrpcConnection {
@@ -291,6 +257,7 @@ async fn cmd_grpc_go<R: Runtime>(
&request.clone().id, &request.clone().id,
uri.as_str(), uri.as_str(),
&proto_files.iter().map(|p| PathBuf::from_str(p).unwrap()).collect(), &proto_files.iter().map(|p| PathBuf::from_str(p).unwrap()).collect(),
&metadata,
) )
.await; .await;
@@ -448,7 +415,7 @@ async fn cmd_grpc_go<R: Runtime>(
match (method_desc.is_client_streaming(), method_desc.is_server_streaming()) { match (method_desc.is_client_streaming(), method_desc.is_server_streaming()) {
(true, true) => ( (true, true) => (
Some( Some(
connection.streaming(&service, &method, in_msg_stream, metadata).await, connection.streaming(&service, &method, in_msg_stream, &metadata).await,
), ),
None, None,
), ),
@@ -456,16 +423,16 @@ async fn cmd_grpc_go<R: Runtime>(
None, None,
Some( Some(
connection connection
.client_streaming(&service, &method, in_msg_stream, metadata) .client_streaming(&service, &method, in_msg_stream, &metadata)
.await, .await,
), ),
), ),
(false, true) => ( (false, true) => (
Some(connection.server_streaming(&service, &method, &msg, metadata).await), Some(connection.server_streaming(&service, &method, &msg, &metadata).await),
None, None,
), ),
(false, false) => { (false, false) => {
(None, Some(connection.unary(&service, &method, &msg, metadata).await)) (None, Some(connection.unary(&service, &method, &msg, &metadata).await))
} }
}; };

View File

@@ -1,13 +1,15 @@
use crate::manager::decorate_req;
use crate::transport::get_transport; use crate::transport::get_transport;
use async_recursion::async_recursion; use async_recursion::async_recursion;
use hyper_rustls::HttpsConnector; use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client; use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::debug; use log::debug;
use std::collections::BTreeMap;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tonic::Request;
use tonic::body::BoxBody; use tonic::body::BoxBody;
use tonic::transport::Uri; use tonic::transport::Uri;
use tonic::Request;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest; use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse; use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use tonic_reflection::pb::v1::{ use tonic_reflection::pb::v1::{
@@ -44,6 +46,7 @@ impl AutoReflectionClient {
pub async fn send_reflection_request( pub async fn send_reflection_request(
&mut self, &mut self,
message: MessageRequest, message: MessageRequest,
metadata: &BTreeMap<String, String>,
) -> Result<MessageResponse, String> { ) -> Result<MessageResponse, String> {
let reflection_request = ServerReflectionRequest { let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter host: "".into(), // Doesn't matter
@@ -51,7 +54,9 @@ impl AutoReflectionClient {
}; };
if self.use_v1alpha { if self.use_v1alpha {
let request = Request::new(tokio_stream::once(to_v1alpha_request(reflection_request))); let mut request = Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
decorate_req(metadata, &mut request).map_err(|e| e.to_string())?;
self.client_v1alpha self.client_v1alpha
.server_reflection_info(request) .server_reflection_info(request)
.await .await
@@ -70,7 +75,9 @@ impl AutoReflectionClient {
.ok_or("No reflection response".to_string()) .ok_or("No reflection response".to_string())
.map(|resp| to_v1_msg_response(resp)) .map(|resp| to_v1_msg_response(resp))
} else { } else {
let request = Request::new(tokio_stream::once(reflection_request)); let mut request = Request::new(tokio_stream::once(reflection_request));
decorate_req(metadata, &mut request).map_err(|e| e.to_string())?;
let resp = self.client_v1.server_reflection_info(request).await; let resp = self.client_v1.server_reflection_info(request).await;
match resp { match resp {
Ok(r) => Ok(r), Ok(r) => Ok(r),
@@ -79,7 +86,7 @@ impl AutoReflectionClient {
// If v1 fails, change to v1alpha and try again // If v1 fails, change to v1alpha and try again
debug!("gRPC schema reflection falling back to v1alpha"); debug!("gRPC schema reflection falling back to v1alpha");
self.use_v1alpha = true; self.use_v1alpha = true;
return self.send_reflection_request(message).await; return self.send_reflection_request(message, metadata).await;
} }
_ => Err(e), _ => Err(e),
}, },

View File

@@ -69,7 +69,7 @@ impl GrpcConnection {
service: &str, service: &str,
method: &str, method: &str,
message: &str, message: &str,
metadata: BTreeMap<String, String>, metadata: &BTreeMap<String, String>,
) -> Result<Response<DynamicMessage>, StreamError> { ) -> Result<Response<DynamicMessage>, StreamError> {
let method = &self.method(&service, &method)?; let method = &self.method(&service, &method)?;
let input_message = method.input(); let input_message = method.input();
@@ -96,7 +96,7 @@ impl GrpcConnection {
service: &str, service: &str,
method: &str, method: &str,
stream: ReceiverStream<DynamicMessage>, stream: ReceiverStream<DynamicMessage>,
metadata: BTreeMap<String, String>, metadata: &BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>, StreamError> { ) -> Result<Response<Streaming<DynamicMessage>>, StreamError> {
let method = &self.method(&service, &method)?; let method = &self.method(&service, &method)?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
@@ -116,7 +116,7 @@ impl GrpcConnection {
service: &str, service: &str,
method: &str, method: &str,
stream: ReceiverStream<DynamicMessage>, stream: ReceiverStream<DynamicMessage>,
metadata: BTreeMap<String, String>, metadata: &BTreeMap<String, String>,
) -> Result<Response<DynamicMessage>, StreamError> { ) -> Result<Response<DynamicMessage>, StreamError> {
let method = &self.method(&service, &method)?; let method = &self.method(&service, &method)?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone()); let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
@@ -137,7 +137,7 @@ impl GrpcConnection {
service: &str, service: &str,
method: &str, method: &str,
message: &str, message: &str,
metadata: BTreeMap<String, String>, metadata: &BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>, StreamError> { ) -> Result<Response<Streaming<DynamicMessage>>, StreamError> {
let method = &self.method(&service, &method)?; let method = &self.method(&service, &method)?;
let input_message = method.input(); let input_message = method.input();
@@ -180,10 +180,11 @@ impl GrpcHandle {
id: &str, id: &str,
uri: &str, uri: &str,
proto_files: &Vec<PathBuf>, proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
) -> Result<(), String> { ) -> Result<(), String> {
let pool = if proto_files.is_empty() { let pool = if proto_files.is_empty() {
let full_uri = uri_from_str(uri)?; let full_uri = uri_from_str(uri)?;
fill_pool_from_reflection(&full_uri).await fill_pool_from_reflection(&full_uri, metadata).await
} else { } else {
fill_pool_from_files(&self.app_handle, proto_files).await fill_pool_from_files(&self.app_handle, proto_files).await
}?; }?;
@@ -197,9 +198,10 @@ impl GrpcHandle {
id: &str, id: &str,
uri: &str, uri: &str,
proto_files: &Vec<PathBuf>, proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
) -> Result<Vec<ServiceDefinition>, String> { ) -> Result<Vec<ServiceDefinition>, String> {
// Ensure reflection is up-to-date // Ensure reflection is up-to-date
self.reflect(id, uri, proto_files).await?; self.reflect(id, uri, proto_files, metadata).await?;
let pool = self.get_pool(id, uri, proto_files).ok_or("Failed to get pool".to_string())?; let pool = self.get_pool(id, uri, proto_files).ok_or("Failed to get pool".to_string())?;
Ok(self.services_from_pool(&pool)) Ok(self.services_from_pool(&pool))
@@ -235,8 +237,9 @@ impl GrpcHandle {
id: &str, id: &str,
uri: &str, uri: &str,
proto_files: &Vec<PathBuf>, proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
) -> Result<GrpcConnection, String> { ) -> Result<GrpcConnection, String> {
self.reflect(id, uri, proto_files).await?; self.reflect(id, uri, proto_files, metadata).await?;
let pool = self.get_pool(id, uri, proto_files).ok_or("Failed to get pool")?; let pool = self.get_pool(id, uri, proto_files).ok_or("Failed to get pool")?;
let uri = uri_from_str(uri)?; let uri = uri_from_str(uri)?;
@@ -254,7 +257,10 @@ impl GrpcHandle {
} }
} }
fn decorate_req<T>(metadata: BTreeMap<String, String>, req: &mut Request<T>) -> Result<(), String> { pub(crate) fn decorate_req<T>(
metadata: &BTreeMap<String, String>,
req: &mut Request<T>,
) -> Result<(), String> {
for (k, v) in metadata { for (k, v) in metadata {
req.metadata_mut().insert( req.metadata_mut().insert(
MetadataKey::from_str(k.as_str()).map_err(|e| e.to_string())?, MetadataKey::from_str(k.as_str()).map_err(|e| e.to_string())?,

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::env::temp_dir; use std::env::temp_dir;
use std::ops::Deref; use std::ops::Deref;
use std::path::PathBuf; use std::path::PathBuf;
@@ -89,11 +90,14 @@ pub async fn fill_pool_from_files(
Ok(pool) Ok(pool)
} }
pub async fn fill_pool_from_reflection(uri: &Uri) -> Result<DescriptorPool, String> { pub async fn fill_pool_from_reflection(
uri: &Uri,
metadata: &BTreeMap<String, String>,
) -> Result<DescriptorPool, String> {
let mut pool = DescriptorPool::new(); let mut pool = DescriptorPool::new();
let mut client = AutoReflectionClient::new(uri); let mut client = AutoReflectionClient::new(uri);
for service in list_services(&mut client).await? { for service in list_services(&mut client, metadata).await? {
if service == "grpc.reflection.v1alpha.ServerReflection" { if service == "grpc.reflection.v1alpha.ServerReflection" {
continue; continue;
} }
@@ -101,14 +105,18 @@ pub async fn fill_pool_from_reflection(uri: &Uri) -> Result<DescriptorPool, Stri
// TODO: update reflection client to use v1 // TODO: update reflection client to use v1
continue; continue;
} }
file_descriptor_set_from_service_name(&service, &mut pool, &mut client).await; file_descriptor_set_from_service_name(&service, &mut pool, &mut client, metadata).await;
} }
Ok(pool) Ok(pool)
} }
async fn list_services(client: &mut AutoReflectionClient) -> Result<Vec<String>, String> { async fn list_services(
let response = client.send_reflection_request(MessageRequest::ListServices("".into())).await?; client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) -> Result<Vec<String>, String> {
let response =
client.send_reflection_request(MessageRequest::ListServices("".into()), metadata).await?;
let list_services_response = match response { let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp, MessageResponse::ListServicesResponse(resp) => resp,
@@ -122,9 +130,13 @@ async fn file_descriptor_set_from_service_name(
service_name: &str, service_name: &str,
pool: &mut DescriptorPool, pool: &mut DescriptorPool,
client: &mut AutoReflectionClient, client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) { ) {
let response = match client let response = match client
.send_reflection_request(MessageRequest::FileContainingSymbol(service_name.into())) .send_reflection_request(
MessageRequest::FileContainingSymbol(service_name.into()),
metadata,
)
.await .await
{ {
Ok(resp) => resp, Ok(resp) => resp,
@@ -139,8 +151,13 @@ async fn file_descriptor_set_from_service_name(
_ => panic!("Expected a FileDescriptorResponse variant"), _ => panic!("Expected a FileDescriptorResponse variant"),
}; };
add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client) add_file_descriptors_to_pool(
.await; file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
} }
#[async_recursion] #[async_recursion]
@@ -148,6 +165,7 @@ async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>, fds: Vec<Vec<u8>>,
pool: &mut DescriptorPool, pool: &mut DescriptorPool,
client: &mut AutoReflectionClient, client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) { ) {
let mut topo_sort = topology::SimpleTopoSort::new(); let mut topo_sort = topology::SimpleTopoSort::new();
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len()); let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
@@ -165,7 +183,7 @@ async fn add_file_descriptors_to_pool(
if let Some(fdp) = fd_mapping.remove(&node) { if let Some(fdp) = fd_mapping.remove(&node) {
pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto"); pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto");
} else { } else {
file_descriptor_set_by_filename(node.as_str(), pool, client).await; file_descriptor_set_by_filename(node.as_str(), pool, client, metadata).await;
} }
} }
Err(_) => panic!("proto file got cycle!"), Err(_) => panic!("proto file got cycle!"),
@@ -177,6 +195,7 @@ async fn file_descriptor_set_by_filename(
filename: &str, filename: &str,
pool: &mut DescriptorPool, pool: &mut DescriptorPool,
client: &mut AutoReflectionClient, client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) { ) {
// We already fetched this file // We already fetched this file
if let Some(_) = pool.get_file_by_name(filename) { if let Some(_) = pool.get_file_by_name(filename) {
@@ -184,7 +203,7 @@ async fn file_descriptor_set_by_filename(
} }
let msg = MessageRequest::FileByFilename(filename.into()); let msg = MessageRequest::FileByFilename(filename.into());
let response = client.send_reflection_request(msg).await; let response = client.send_reflection_request(msg, metadata).await;
let file_descriptor_response = match response { let file_descriptor_response = match response {
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp, Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
Ok(_) => { Ok(_) => {
@@ -196,8 +215,13 @@ async fn file_descriptor_set_by_filename(
} }
}; };
add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client) add_file_descriptors_to_pool(
.await; file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
} }
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery { pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {