Auth plugins (#155)

This commit is contained in:
Gregory Schier
2025-01-17 05:53:03 -08:00
committed by GitHub
parent e21df98a30
commit bd322162c8
56 changed files with 5468 additions and 1474 deletions

View File

@@ -0,0 +1,172 @@
use crate::transport::get_transport;
use async_recursion::async_recursion;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use log::debug;
use tokio_stream::StreamExt;
use tonic::body::BoxBody;
use tonic::transport::Uri;
use tonic::Request;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use tonic_reflection::pb::v1::{
ErrorResponse, ExtensionNumberResponse, ListServiceResponse, ServerReflectionRequest,
ServiceResponse,
};
use tonic_reflection::pb::v1::{ExtensionRequest, FileDescriptorResponse};
use tonic_reflection::pb::{v1, v1alpha};
pub struct AutoReflectionClient<T = Client<HttpsConnector<HttpConnector>, BoxBody>> {
use_v1alpha: bool,
client_v1: v1::server_reflection_client::ServerReflectionClient<T>,
client_v1alpha: v1alpha::server_reflection_client::ServerReflectionClient<T>,
}
impl AutoReflectionClient {
pub fn new(uri: &Uri) -> Self {
let client_v1 = v1::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(),
uri.clone(),
);
let client_v1alpha = v1alpha::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(),
uri.clone(),
);
AutoReflectionClient {
use_v1alpha: false,
client_v1,
client_v1alpha,
}
}
#[async_recursion]
pub async fn send_reflection_request(
&mut self,
message: MessageRequest,
) -> Result<MessageResponse, String> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message.clone()),
};
if self.use_v1alpha {
let request = Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
self.client_v1alpha
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
.map(|resp| to_v1_msg_response(resp))
} else {
let request = Request::new(tokio_stream::once(reflection_request));
let resp = self.client_v1.server_reflection_info(request).await;
match resp {
Ok(r) => Ok(r),
Err(e) => match e.code().clone() {
tonic::Code::Unimplemented => {
// If v1 fails, change to v1alpha and try again
debug!("gRPC schema reflection falling back to v1alpha");
self.use_v1alpha = true;
return self.send_reflection_request(message).await;
}
_ => Err(e),
},
}
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
}
}
}
fn to_v1_msg_response(
response: v1alpha::server_reflection_response::MessageResponse,
) -> MessageResponse {
match response {
v1alpha::server_reflection_response::MessageResponse::FileDescriptorResponse(v) => {
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
file_descriptor_proto: v.file_descriptor_proto,
})
}
v1alpha::server_reflection_response::MessageResponse::AllExtensionNumbersResponse(v) => {
MessageResponse::AllExtensionNumbersResponse(ExtensionNumberResponse {
extension_number: v.extension_number,
base_type_name: v.base_type_name,
})
}
v1alpha::server_reflection_response::MessageResponse::ListServicesResponse(v) => {
MessageResponse::ListServicesResponse(ListServiceResponse {
service: v
.service
.iter()
.map(|s| ServiceResponse {
name: s.name.clone(),
})
.collect(),
})
}
v1alpha::server_reflection_response::MessageResponse::ErrorResponse(v) => {
MessageResponse::ErrorResponse(ErrorResponse {
error_code: v.error_code,
error_message: v.error_message,
})
}
}
}
fn to_v1alpha_request(request: ServerReflectionRequest) -> v1alpha::ServerReflectionRequest {
v1alpha::ServerReflectionRequest {
host: request.host,
message_request: request.message_request.map(|m| to_v1alpha_msg_request(m)),
}
}
fn to_v1alpha_msg_request(
message: MessageRequest,
) -> v1alpha::server_reflection_request::MessageRequest {
match message {
MessageRequest::FileByFilename(v) => {
v1alpha::server_reflection_request::MessageRequest::FileByFilename(v)
}
MessageRequest::FileContainingSymbol(v) => {
v1alpha::server_reflection_request::MessageRequest::FileContainingSymbol(v)
}
MessageRequest::FileContainingExtension(ExtensionRequest {
extension_number,
containing_type,
}) => v1alpha::server_reflection_request::MessageRequest::FileContainingExtension(
v1alpha::ExtensionRequest {
extension_number,
containing_type,
},
),
MessageRequest::AllExtensionNumbersOfType(v) => {
v1alpha::server_reflection_request::MessageRequest::AllExtensionNumbersOfType(v)
}
MessageRequest::ListServices(v) => {
v1alpha::server_reflection_request::MessageRequest::ListServices(v)
}
}
}

View File

@@ -5,7 +5,9 @@ use serde_json::Deserializer;
mod codec;
mod json_schema;
pub mod manager;
mod proto;
mod reflection;
mod transport;
mod client;
pub use tonic::metadata::*;
pub use tonic::Code;

View File

@@ -2,9 +2,9 @@ use std::collections::BTreeMap;
use std::path::PathBuf;
use std::str::FromStr;
use hyper::client::HttpConnector;
use hyper::Client;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
pub use prost_reflect::DynamicMessage;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
@@ -16,10 +16,11 @@ use tonic::transport::Uri;
use tonic::{IntoRequest, IntoStreamingRequest, Request, Response, Status, Streaming};
use crate::codec::DynamicCodec;
use crate::proto::{
fill_pool_from_files, fill_pool_from_reflection, get_transport, method_desc_to_path,
use crate::reflection::{
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path,
};
use crate::{json_schema, MethodDefinition, ServiceDefinition};
use crate::transport::get_transport;
#[derive(Clone)]
pub struct GrpcConnection {

View File

@@ -3,11 +3,9 @@ use std::ops::Deref;
use std::path::PathBuf;
use std::str::FromStr;
use crate::client::AutoReflectionClient;
use anyhow::anyhow;
use async_recursion::async_recursion;
use hyper::client::HttpConnector;
use hyper::Client;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use log::{debug, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
@@ -16,15 +14,10 @@ use tauri::path::BaseDirectory;
use tauri::{AppHandle, Manager};
use tauri_plugin_shell::ShellExt;
use tokio::fs;
use tokio_stream::StreamExt;
use tonic::body::BoxBody;
use tonic::codegen::http::uri::PathAndQuery;
use tonic::transport::Uri;
use tonic::Request;
use tonic_reflection::pb::server_reflection_client::ServerReflectionClient;
use tonic_reflection::pb::server_reflection_request::MessageRequest;
use tonic_reflection::pb::server_reflection_response::MessageResponse;
use tonic_reflection::pb::ServerReflectionRequest;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
pub async fn fill_pool_from_files(
app_handle: &AppHandle,
@@ -98,7 +91,7 @@ pub async fn fill_pool_from_files(
pub async fn fill_pool_from_reflection(uri: &Uri) -> Result<DescriptorPool, String> {
let mut pool = DescriptorPool::new();
let mut client = ServerReflectionClient::with_origin(get_transport(), uri.clone());
let mut client = AutoReflectionClient::new(uri);
for service in list_services(&mut client).await? {
if service == "grpc.reflection.v1alpha.ServerReflection" {
@@ -114,21 +107,8 @@ pub async fn fill_pool_from_reflection(uri: &Uri) -> Result<DescriptorPool, Stri
Ok(pool)
}
pub fn get_transport() -> Client<HttpsConnector<HttpConnector>, BoxBody> {
let connector = HttpsConnectorBuilder::new().with_native_roots();
let connector = connector.https_or_http().enable_http2().wrap_connector({
let mut http_connector = HttpConnector::new();
http_connector.enforce_http(false);
http_connector
});
Client::builder().pool_max_idle_per_host(0).http2_only(true).build(connector)
}
async fn list_services(
reflect_client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
) -> Result<Vec<String>, String> {
let response =
send_reflection_request(reflect_client, MessageRequest::ListServices("".into())).await?;
async fn list_services(client: &mut AutoReflectionClient) -> Result<Vec<String>, String> {
let response = client.send_reflection_request(MessageRequest::ListServices("".into())).await?;
let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp,
@@ -141,13 +121,11 @@ async fn list_services(
async fn file_descriptor_set_from_service_name(
service_name: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
client: &mut AutoReflectionClient,
) {
let response = match send_reflection_request(
client,
MessageRequest::FileContainingSymbol(service_name.into()),
)
.await
let response = match client
.send_reflection_request(MessageRequest::FileContainingSymbol(service_name.into()))
.await
{
Ok(resp) => resp,
Err(e) => {
@@ -169,7 +147,7 @@ async fn file_descriptor_set_from_service_name(
async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
client: &mut AutoReflectionClient,
) {
let mut topo_sort = topology::SimpleTopoSort::new();
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
@@ -198,15 +176,15 @@ async fn add_file_descriptors_to_pool(
async fn file_descriptor_set_by_filename(
filename: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
client: &mut AutoReflectionClient,
) {
// We already fetched this file
if let Some(_) = pool.get_file_by_name(filename) {
return;
}
let response =
send_reflection_request(client, MessageRequest::FileByFilename(filename.into())).await;
let msg = MessageRequest::FileByFilename(filename.into());
let response = client.send_reflection_request(msg).await;
let file_descriptor_response = match response {
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
Ok(_) => {
@@ -222,35 +200,6 @@ async fn file_descriptor_set_by_filename(
.await;
}
async fn send_reflection_request(
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
message: MessageRequest,
) -> Result<MessageResponse, String> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message),
};
let request = Request::new(tokio_stream::once(reflection_request));
client
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
}
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
let full_name = md.full_name();
let (namespace, method_name) = full_name
@@ -292,8 +241,8 @@ mod topology {
where
T: Eq + std::hash::Hash + Clone,
{
type IntoIter = SimpleTopoSortIter<T>;
type Item = <SimpleTopoSortIter<T> as Iterator>::Item;
type IntoIter = SimpleTopoSortIter<T>;
fn into_iter(self) -> Self::IntoIter {
SimpleTopoSortIter::new(self)

View File

@@ -0,0 +1,19 @@
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use tonic::body::BoxBody;
pub(crate) fn get_transport() -> Client<HttpsConnector<HttpConnector>, BoxBody> {
let connector = HttpsConnectorBuilder::new().with_platform_verifier();
let connector = connector.https_or_http().enable_http2().wrap_connector({
let mut http_connector = HttpConnector::new();
http_connector.enforce_http(false);
http_connector
});
Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(0)
.http2_only(true)
.build(connector)
}