Files
yaak/src-tauri/yaak-grpc/src/client.rs
2025-01-17 05:53:03 -08:00

173 lines
6.6 KiB
Rust

use crate::transport::get_transport;
use async_recursion::async_recursion;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use log::debug;
use tokio_stream::StreamExt;
use tonic::body::BoxBody;
use tonic::transport::Uri;
use tonic::Request;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use tonic_reflection::pb::v1::{
ErrorResponse, ExtensionNumberResponse, ListServiceResponse, ServerReflectionRequest,
ServiceResponse,
};
use tonic_reflection::pb::v1::{ExtensionRequest, FileDescriptorResponse};
use tonic_reflection::pb::{v1, v1alpha};
/// gRPC server-reflection client that talks to the `v1` reflection service
/// first and switches to `v1alpha` once the server reports `v1` as
/// unimplemented.
pub struct AutoReflectionClient<T = Client<HttpsConnector<HttpConnector>, BoxBody>> {
    // Sticky flag: starts `false`; flipped to `true` the first time a `v1`
    // call fails with `Unimplemented`, after which every request uses `v1alpha`.
    use_v1alpha: bool,
    // Reflection client for the `v1` protocol version.
    client_v1: v1::server_reflection_client::ServerReflectionClient<T>,
    // Reflection client for the `v1alpha` protocol version (fallback).
    client_v1alpha: v1alpha::server_reflection_client::ServerReflectionClient<T>,
}
impl AutoReflectionClient {
pub fn new(uri: &Uri) -> Self {
let client_v1 = v1::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(),
uri.clone(),
);
let client_v1alpha = v1alpha::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(),
uri.clone(),
);
AutoReflectionClient {
use_v1alpha: false,
client_v1,
client_v1alpha,
}
}
#[async_recursion]
pub async fn send_reflection_request(
&mut self,
message: MessageRequest,
) -> Result<MessageResponse, String> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message.clone()),
};
if self.use_v1alpha {
let request = Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
self.client_v1alpha
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
.map(|resp| to_v1_msg_response(resp))
} else {
let request = Request::new(tokio_stream::once(reflection_request));
let resp = self.client_v1.server_reflection_info(request).await;
match resp {
Ok(r) => Ok(r),
Err(e) => match e.code().clone() {
tonic::Code::Unimplemented => {
// If v1 fails, change to v1alpha and try again
debug!("gRPC schema reflection falling back to v1alpha");
self.use_v1alpha = true;
return self.send_reflection_request(message).await;
}
_ => Err(e),
},
}
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
}
}
}
/// Converts a `v1alpha` reflection response payload into its `v1` equivalent.
///
/// Each variant is mapped field-for-field. `response` is taken by value, so
/// owned data (descriptor bytes, service names) is moved rather than cloned.
fn to_v1_msg_response(
    response: v1alpha::server_reflection_response::MessageResponse,
) -> MessageResponse {
    match response {
        v1alpha::server_reflection_response::MessageResponse::FileDescriptorResponse(v) => {
            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
                file_descriptor_proto: v.file_descriptor_proto,
            })
        }
        v1alpha::server_reflection_response::MessageResponse::AllExtensionNumbersResponse(v) => {
            MessageResponse::AllExtensionNumbersResponse(ExtensionNumberResponse {
                extension_number: v.extension_number,
                base_type_name: v.base_type_name,
            })
        }
        v1alpha::server_reflection_response::MessageResponse::ListServicesResponse(v) => {
            MessageResponse::ListServicesResponse(ListServiceResponse {
                // Move each name out instead of cloning it from a borrowed iterator.
                service: v
                    .service
                    .into_iter()
                    .map(|s| ServiceResponse { name: s.name })
                    .collect(),
            })
        }
        v1alpha::server_reflection_response::MessageResponse::ErrorResponse(v) => {
            MessageResponse::ErrorResponse(ErrorResponse {
                error_code: v.error_code,
                error_message: v.error_message,
            })
        }
    }
}
fn to_v1alpha_request(request: ServerReflectionRequest) -> v1alpha::ServerReflectionRequest {
v1alpha::ServerReflectionRequest {
host: request.host,
message_request: request.message_request.map(|m| to_v1alpha_msg_request(m)),
}
}
fn to_v1alpha_msg_request(
message: MessageRequest,
) -> v1alpha::server_reflection_request::MessageRequest {
match message {
MessageRequest::FileByFilename(v) => {
v1alpha::server_reflection_request::MessageRequest::FileByFilename(v)
}
MessageRequest::FileContainingSymbol(v) => {
v1alpha::server_reflection_request::MessageRequest::FileContainingSymbol(v)
}
MessageRequest::FileContainingExtension(ExtensionRequest {
extension_number,
containing_type,
}) => v1alpha::server_reflection_request::MessageRequest::FileContainingExtension(
v1alpha::ExtensionRequest {
extension_number,
containing_type,
},
),
MessageRequest::AllExtensionNumbersOfType(v) => {
v1alpha::server_reflection_request::MessageRequest::AllExtensionNumbersOfType(v)
}
MessageRequest::ListServices(v) => {
v1alpha::server_reflection_request::MessageRequest::ListServices(v)
}
}
}