mirror of
https://github.com/mountain-loop/yaak.git
synced 2026-04-24 01:28:35 +02:00
Decouple core Yaak logic from Tauri (#354)
This commit is contained in:
60
crates/yaak-grpc/src/any.rs
Normal file
60
crates/yaak-grpc/src/any.rs
Normal file
@@ -0,0 +1,60 @@
|
||||
use log::error;
|
||||
|
||||
pub(crate) fn collect_any_types(json: &str, out: &mut Vec<String>) {
|
||||
let value = match serde_json::from_str(json).map_err(|e| e.to_string()) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("Failed to parse gRPC message JSON: {e:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
collect_any_types_value(&value, out);
|
||||
}
|
||||
|
||||
fn collect_any_types_value(json: &serde_json::Value, out: &mut Vec<String>) {
|
||||
match json {
|
||||
serde_json::Value::Object(map) => {
|
||||
if let Some(t) = map.get("@type").and_then(|v| v.as_str()) {
|
||||
if let Some(full_name) = t.rsplit_once('/').map(|(_, n)| n) {
|
||||
out.push(full_name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
for v in map.values() {
|
||||
collect_any_types_value(v, out);
|
||||
}
|
||||
}
|
||||
serde_json::Value::Array(arr) => {
|
||||
for v in arr {
|
||||
collect_any_types_value(v, out);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Write tests for this
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn test_collect_any_types() {
|
||||
let json = r#"{
|
||||
"mounts": [
|
||||
{
|
||||
"mountSource": {
|
||||
"@type": "type.googleapis.com/mount_source.MountSourceRBDVolume",
|
||||
"volumeID": "volumes/rbd"
|
||||
}
|
||||
}
|
||||
],
|
||||
"foo": {
|
||||
"@type": "type.googleapis.com/foo.bar",
|
||||
"foo": "fooo"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let mut out = Vec::new();
|
||||
super::collect_any_types(json, &mut out);
|
||||
assert_eq!(out, vec!["foo.bar", "mount_source.MountSourceRBDVolume"]);
|
||||
}
|
||||
}
|
||||
182
crates/yaak-grpc/src/client.rs
Normal file
182
crates/yaak-grpc/src/client.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use crate::error::Error::GenericError;
|
||||
use crate::error::Result;
|
||||
use crate::manager::decorate_req;
|
||||
use crate::transport::get_transport;
|
||||
use async_recursion::async_recursion;
|
||||
use hyper_rustls::HttpsConnector;
|
||||
use hyper_util::client::legacy::Client;
|
||||
use hyper_util::client::legacy::connect::HttpConnector;
|
||||
use log::debug;
|
||||
use std::collections::BTreeMap;
|
||||
use tokio_stream::StreamExt;
|
||||
use tonic::Request;
|
||||
use tonic::body::BoxBody;
|
||||
use tonic::transport::Uri;
|
||||
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
|
||||
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
|
||||
use tonic_reflection::pb::v1::{
|
||||
ErrorResponse, ExtensionNumberResponse, ListServiceResponse, ServerReflectionRequest,
|
||||
ServiceResponse,
|
||||
};
|
||||
use tonic_reflection::pb::v1::{ExtensionRequest, FileDescriptorResponse};
|
||||
use tonic_reflection::pb::{v1, v1alpha};
|
||||
use yaak_tls::ClientCertificateConfig;
|
||||
|
||||
pub struct AutoReflectionClient<T = Client<HttpsConnector<HttpConnector>, BoxBody>> {
|
||||
use_v1alpha: bool,
|
||||
client_v1: v1::server_reflection_client::ServerReflectionClient<T>,
|
||||
client_v1alpha: v1alpha::server_reflection_client::ServerReflectionClient<T>,
|
||||
}
|
||||
|
||||
impl AutoReflectionClient {
|
||||
pub fn new(
|
||||
uri: &Uri,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<Self> {
|
||||
let client_v1 = v1::server_reflection_client::ServerReflectionClient::with_origin(
|
||||
get_transport(validate_certificates, client_cert.clone())?,
|
||||
uri.clone(),
|
||||
);
|
||||
let client_v1alpha = v1alpha::server_reflection_client::ServerReflectionClient::with_origin(
|
||||
get_transport(validate_certificates, client_cert.clone())?,
|
||||
uri.clone(),
|
||||
);
|
||||
Ok(AutoReflectionClient { use_v1alpha: false, client_v1, client_v1alpha })
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
pub async fn send_reflection_request(
|
||||
&mut self,
|
||||
message: MessageRequest,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) -> Result<MessageResponse> {
|
||||
let reflection_request = ServerReflectionRequest {
|
||||
host: "".into(), // Doesn't matter
|
||||
message_request: Some(message.clone()),
|
||||
};
|
||||
|
||||
if self.use_v1alpha {
|
||||
let mut request =
|
||||
Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
|
||||
decorate_req(metadata, &mut request)?;
|
||||
|
||||
self.client_v1alpha
|
||||
.server_reflection_info(request)
|
||||
.await
|
||||
.map_err(|e| match e.code() {
|
||||
tonic::Code::Unavailable => {
|
||||
GenericError("Failed to connect to endpoint".to_string())
|
||||
}
|
||||
tonic::Code::Unauthenticated => {
|
||||
GenericError("Authentication failed".to_string())
|
||||
}
|
||||
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
|
||||
_ => GenericError(e.to_string()),
|
||||
})?
|
||||
.into_inner()
|
||||
.next()
|
||||
.await
|
||||
.ok_or(GenericError("Missing reflection message".to_string()))??
|
||||
.message_response
|
||||
.ok_or(GenericError("No reflection response".to_string()))
|
||||
.map(|resp| to_v1_msg_response(resp))
|
||||
} else {
|
||||
let mut request = Request::new(tokio_stream::once(reflection_request));
|
||||
decorate_req(metadata, &mut request)?;
|
||||
|
||||
let resp = self.client_v1.server_reflection_info(request).await;
|
||||
match resp {
|
||||
Ok(r) => Ok(r),
|
||||
Err(e) => match e.code().clone() {
|
||||
tonic::Code::Unimplemented => {
|
||||
// If v1 fails, change to v1alpha and try again
|
||||
debug!("gRPC schema reflection falling back to v1alpha");
|
||||
self.use_v1alpha = true;
|
||||
return self.send_reflection_request(message, metadata).await;
|
||||
}
|
||||
_ => Err(e),
|
||||
},
|
||||
}
|
||||
.map_err(|e| match e.code() {
|
||||
tonic::Code::Unavailable => {
|
||||
GenericError("Failed to connect to endpoint".to_string())
|
||||
}
|
||||
tonic::Code::Unauthenticated => GenericError("Authentication failed".to_string()),
|
||||
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
|
||||
_ => GenericError(e.to_string()),
|
||||
})?
|
||||
.into_inner()
|
||||
.next()
|
||||
.await
|
||||
.ok_or(GenericError("Missing reflection message".to_string()))??
|
||||
.message_response
|
||||
.ok_or(GenericError("No reflection response".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_v1_msg_response(
|
||||
response: v1alpha::server_reflection_response::MessageResponse,
|
||||
) -> MessageResponse {
|
||||
match response {
|
||||
v1alpha::server_reflection_response::MessageResponse::FileDescriptorResponse(v) => {
|
||||
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
|
||||
file_descriptor_proto: v.file_descriptor_proto,
|
||||
})
|
||||
}
|
||||
v1alpha::server_reflection_response::MessageResponse::AllExtensionNumbersResponse(v) => {
|
||||
MessageResponse::AllExtensionNumbersResponse(ExtensionNumberResponse {
|
||||
extension_number: v.extension_number,
|
||||
base_type_name: v.base_type_name,
|
||||
})
|
||||
}
|
||||
v1alpha::server_reflection_response::MessageResponse::ListServicesResponse(v) => {
|
||||
MessageResponse::ListServicesResponse(ListServiceResponse {
|
||||
service: v
|
||||
.service
|
||||
.iter()
|
||||
.map(|s| ServiceResponse { name: s.name.clone() })
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
v1alpha::server_reflection_response::MessageResponse::ErrorResponse(v) => {
|
||||
MessageResponse::ErrorResponse(ErrorResponse {
|
||||
error_code: v.error_code,
|
||||
error_message: v.error_message,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn to_v1alpha_request(request: ServerReflectionRequest) -> v1alpha::ServerReflectionRequest {
|
||||
v1alpha::ServerReflectionRequest {
|
||||
host: request.host,
|
||||
message_request: request.message_request.map(|m| to_v1alpha_msg_request(m)),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_v1alpha_msg_request(
|
||||
message: MessageRequest,
|
||||
) -> v1alpha::server_reflection_request::MessageRequest {
|
||||
match message {
|
||||
MessageRequest::FileByFilename(v) => {
|
||||
v1alpha::server_reflection_request::MessageRequest::FileByFilename(v)
|
||||
}
|
||||
MessageRequest::FileContainingSymbol(v) => {
|
||||
v1alpha::server_reflection_request::MessageRequest::FileContainingSymbol(v)
|
||||
}
|
||||
MessageRequest::FileContainingExtension(ExtensionRequest {
|
||||
extension_number,
|
||||
containing_type,
|
||||
}) => v1alpha::server_reflection_request::MessageRequest::FileContainingExtension(
|
||||
v1alpha::ExtensionRequest { extension_number, containing_type },
|
||||
),
|
||||
MessageRequest::AllExtensionNumbersOfType(v) => {
|
||||
v1alpha::server_reflection_request::MessageRequest::AllExtensionNumbersOfType(v)
|
||||
}
|
||||
MessageRequest::ListServices(v) => {
|
||||
v1alpha::server_reflection_request::MessageRequest::ListServices(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
crates/yaak-grpc/src/codec.rs
Normal file
50
crates/yaak-grpc/src/codec.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use prost_reflect::prost::Message;
|
||||
use prost_reflect::{DynamicMessage, MethodDescriptor};
|
||||
use tonic::Status;
|
||||
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DynamicCodec(MethodDescriptor);
|
||||
|
||||
impl DynamicCodec {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(md: MethodDescriptor) -> Self {
|
||||
Self(md)
|
||||
}
|
||||
}
|
||||
|
||||
impl Codec for DynamicCodec {
|
||||
type Encode = DynamicMessage;
|
||||
type Decode = DynamicMessage;
|
||||
type Encoder = Self;
|
||||
type Decoder = Self;
|
||||
|
||||
fn encoder(&mut self) -> Self::Encoder {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn decoder(&mut self) -> Self::Decoder {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for DynamicCodec {
|
||||
type Item = DynamicMessage;
|
||||
type Error = Status;
|
||||
|
||||
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
|
||||
item.encode(dst).expect("buffer is too small to decode this message");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for DynamicCodec {
|
||||
type Item = DynamicMessage;
|
||||
type Error = Status;
|
||||
|
||||
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
|
||||
let mut msg = DynamicMessage::new(self.0.output());
|
||||
msg.merge(src).map_err(|err| Status::internal(err.to_string()))?;
|
||||
Ok(Some(msg))
|
||||
}
|
||||
}
|
||||
51
crates/yaak-grpc/src/error.rs
Normal file
51
crates/yaak-grpc/src/error.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use crate::manager::GrpcStreamError;
|
||||
use prost::DecodeError;
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::Error as SerdeJsonError;
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tonic::Status;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
TlsError(#[from] yaak_tls::error::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
TonicError(#[from] Status),
|
||||
|
||||
#[error("Prost reflect error: {0:?}")]
|
||||
ProstReflectError(#[from] prost_reflect::DescriptorError),
|
||||
|
||||
#[error(transparent)]
|
||||
DeserializerError(#[from] SerdeJsonError),
|
||||
|
||||
#[error(transparent)]
|
||||
GrpcStreamError(#[from] GrpcStreamError),
|
||||
|
||||
#[error(transparent)]
|
||||
GrpcDecodeError(#[from] DecodeError),
|
||||
|
||||
#[error(transparent)]
|
||||
GrpcInvalidMetadataKeyError(#[from] tonic::metadata::errors::InvalidMetadataKey),
|
||||
|
||||
#[error(transparent)]
|
||||
GrpcInvalidMetadataValueError(#[from] tonic::metadata::errors::InvalidMetadataValue),
|
||||
|
||||
#[error(transparent)]
|
||||
IOError(#[from] io::Error),
|
||||
|
||||
#[error("GRPC error: {0}")]
|
||||
GenericError(String),
|
||||
}
|
||||
|
||||
impl Serialize for Error {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(self.to_string().as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
382
crates/yaak-grpc/src/json_schema.rs
Normal file
382
crates/yaak-grpc/src/json_schema.rs
Normal file
@@ -0,0 +1,382 @@
|
||||
use prost_reflect::{DescriptorPool, FieldDescriptor, MessageDescriptor};
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
|
||||
pub fn message_to_json_schema(_: &DescriptorPool, root_msg: MessageDescriptor) -> JsonSchemaEntry {
|
||||
JsonSchemaGenerator::generate_json_schema(root_msg)
|
||||
}
|
||||
|
||||
struct JsonSchemaGenerator {
|
||||
msg_mapping: HashMap<String, JsonSchemaEntry>,
|
||||
}
|
||||
|
||||
impl JsonSchemaGenerator {
|
||||
pub fn new() -> Self {
|
||||
JsonSchemaGenerator { msg_mapping: HashMap::new() }
|
||||
}
|
||||
|
||||
pub fn generate_json_schema(msg: MessageDescriptor) -> JsonSchemaEntry {
|
||||
let generator = JsonSchemaGenerator::new();
|
||||
generator.scan_root(msg)
|
||||
}
|
||||
|
||||
fn add_message(&mut self, msg: &MessageDescriptor) {
|
||||
let name = msg.full_name().to_string();
|
||||
if self.msg_mapping.contains_key(&name) {
|
||||
return;
|
||||
}
|
||||
self.msg_mapping.insert(name.clone(), JsonSchemaEntry::object());
|
||||
}
|
||||
|
||||
pub fn scan_root(mut self, root_msg: MessageDescriptor) -> JsonSchemaEntry {
|
||||
self.init_structure(root_msg.clone());
|
||||
self.fill_properties(root_msg.clone());
|
||||
|
||||
let mut root = self.msg_mapping.remove(root_msg.full_name()).unwrap();
|
||||
|
||||
if self.msg_mapping.len() > 0 {
|
||||
root.defs = Some(self.msg_mapping);
|
||||
}
|
||||
root
|
||||
}
|
||||
|
||||
fn fill_properties(&mut self, root_msg: MessageDescriptor) {
|
||||
let root_name = root_msg.full_name().to_string();
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
let mut msg_queue = VecDeque::new();
|
||||
msg_queue.push_back(root_msg);
|
||||
|
||||
while !msg_queue.is_empty() {
|
||||
let msg = msg_queue.pop_front().unwrap();
|
||||
let msg_name = msg.full_name();
|
||||
if visited.contains(msg_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
visited.insert(msg_name.to_string());
|
||||
|
||||
let entry = self.msg_mapping.get_mut(msg_name).unwrap();
|
||||
|
||||
for field in msg.fields() {
|
||||
let field_name = field.name().to_string();
|
||||
|
||||
if matches!(field.cardinality(), prost_reflect::Cardinality::Required) {
|
||||
entry.add_required(field_name.clone());
|
||||
}
|
||||
|
||||
if let Some(oneof) = field.containing_oneof() {
|
||||
for oneof_field in oneof.fields() {
|
||||
if let Some(fm) = is_message_field(&oneof_field) {
|
||||
msg_queue.push_back(fm);
|
||||
}
|
||||
entry.add_property(
|
||||
oneof_field.name().to_string(),
|
||||
field_to_type_or_ref(&root_name, oneof_field),
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let (field_type, nest_msg) = {
|
||||
if let Some(fm) = is_message_field(&field) {
|
||||
if field.is_list() {
|
||||
// repeated message type
|
||||
(
|
||||
JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)),
|
||||
Some(fm),
|
||||
)
|
||||
} else if field.is_map() {
|
||||
let value_field = fm.get_field_by_name("value").unwrap();
|
||||
|
||||
if let Some(fm) = is_message_field(&value_field) {
|
||||
(
|
||||
JsonSchemaEntry::map(field_to_type_or_ref(
|
||||
&root_name,
|
||||
value_field,
|
||||
)),
|
||||
Some(fm),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
JsonSchemaEntry::map(field_to_type_or_ref(
|
||||
&root_name,
|
||||
value_field,
|
||||
)),
|
||||
None,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
(field_to_type_or_ref(&root_name, field), Some(fm))
|
||||
}
|
||||
} else {
|
||||
if field.is_list() {
|
||||
// repeated scalar type
|
||||
(JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), None)
|
||||
} else {
|
||||
(field_to_type_or_ref(&root_name, field), None)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(fm) = nest_msg {
|
||||
msg_queue.push_back(fm);
|
||||
}
|
||||
|
||||
entry.add_property(field_name, field_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn init_structure(&mut self, root_msg: MessageDescriptor) {
|
||||
let mut visited = HashSet::new();
|
||||
let mut msg_queue = VecDeque::new();
|
||||
msg_queue.push_back(root_msg.clone());
|
||||
|
||||
// level traversal, to make sure all message type is defined before used
|
||||
while !msg_queue.is_empty() {
|
||||
let msg = msg_queue.pop_front().unwrap();
|
||||
let name = msg.full_name();
|
||||
if visited.contains(name) {
|
||||
continue;
|
||||
}
|
||||
visited.insert(name.to_string());
|
||||
self.add_message(&msg);
|
||||
|
||||
for child in msg.child_messages() {
|
||||
if child.is_map_entry() {
|
||||
// for field with map<key, value> type, there will be a child message type *Entry generated
|
||||
// just skip it
|
||||
continue;
|
||||
}
|
||||
|
||||
self.add_message(&child);
|
||||
msg_queue.push_back(child);
|
||||
}
|
||||
|
||||
for field in msg.fields() {
|
||||
if let Some(oneof) = field.containing_oneof() {
|
||||
for oneof_field in oneof.fields() {
|
||||
if let Some(fm) = is_message_field(&oneof_field) {
|
||||
self.add_message(&fm);
|
||||
msg_queue.push_back(fm);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if field.is_map() {
|
||||
// key is always scalar type, so no need to process
|
||||
// value can be any type, so need to unpack value type
|
||||
let map_field_msg = is_message_field(&field).unwrap();
|
||||
let map_value_field = map_field_msg.get_field_by_name("value").unwrap();
|
||||
if let Some(value_fm) = is_message_field(&map_value_field) {
|
||||
self.add_message(&value_fm);
|
||||
msg_queue.push_back(value_fm);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Some(fm) = is_message_field(&field) {
|
||||
self.add_message(&fm);
|
||||
msg_queue.push_back(fm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn field_to_type_or_ref(root_name: &str, field: FieldDescriptor) -> JsonSchemaEntry {
|
||||
match field.kind() {
|
||||
prost_reflect::Kind::Bool => JsonSchemaEntry::boolean(),
|
||||
prost_reflect::Kind::Double => JsonSchemaEntry::number("double"),
|
||||
prost_reflect::Kind::Float => JsonSchemaEntry::number("float"),
|
||||
prost_reflect::Kind::Int32 => JsonSchemaEntry::number("int32"),
|
||||
prost_reflect::Kind::Int64 => JsonSchemaEntry::string_with_format("int64"),
|
||||
prost_reflect::Kind::Uint32 => JsonSchemaEntry::number("int64"),
|
||||
prost_reflect::Kind::Uint64 => JsonSchemaEntry::string_with_format("uint64"),
|
||||
prost_reflect::Kind::Sint32 => JsonSchemaEntry::number("sint32"),
|
||||
prost_reflect::Kind::Sint64 => JsonSchemaEntry::string_with_format("sint64"),
|
||||
prost_reflect::Kind::Fixed32 => JsonSchemaEntry::number("int64"),
|
||||
prost_reflect::Kind::Fixed64 => JsonSchemaEntry::string_with_format("fixed64"),
|
||||
prost_reflect::Kind::Sfixed32 => JsonSchemaEntry::number("sfixed32"),
|
||||
prost_reflect::Kind::Sfixed64 => JsonSchemaEntry::string_with_format("sfixed64"),
|
||||
prost_reflect::Kind::String => JsonSchemaEntry::string(),
|
||||
prost_reflect::Kind::Bytes => JsonSchemaEntry::string_with_format("byte"),
|
||||
prost_reflect::Kind::Enum(enums) => {
|
||||
let values = enums.values().map(|v| v.name().to_string()).collect::<Vec<_>>();
|
||||
JsonSchemaEntry::enums(values)
|
||||
}
|
||||
prost_reflect::Kind::Message(fm) => {
|
||||
let field_type_full_name = fm.full_name();
|
||||
match field_type_full_name {
|
||||
// [Protocol Buffers Well-Known Types]: https://protobuf.dev/reference/protobuf/google.protobuf/
|
||||
"google.protobuf.FieldMask" => JsonSchemaEntry::string(),
|
||||
"google.protobuf.Timestamp" => JsonSchemaEntry::string_with_format("date-time"),
|
||||
"google.protobuf.Duration" => JsonSchemaEntry::string(),
|
||||
"google.protobuf.StringValue" => JsonSchemaEntry::string(),
|
||||
"google.protobuf.BytesValue" => JsonSchemaEntry::string_with_format("byte"),
|
||||
"google.protobuf.Int32Value" => JsonSchemaEntry::number("int32"),
|
||||
"google.protobuf.UInt32Value" => JsonSchemaEntry::string_with_format("int64"),
|
||||
"google.protobuf.Int64Value" => JsonSchemaEntry::string_with_format("int64"),
|
||||
"google.protobuf.UInt64Value" => JsonSchemaEntry::string_with_format("uint64"),
|
||||
"google.protobuf.FloatValue" => JsonSchemaEntry::number("float"),
|
||||
"google.protobuf.DoubleValue" => JsonSchemaEntry::number("double"),
|
||||
"google.protobuf.BoolValue" => JsonSchemaEntry::boolean(),
|
||||
"google.protobuf.Empty" => JsonSchemaEntry::default(),
|
||||
"google.protobuf.Struct" => JsonSchemaEntry::object(),
|
||||
"google.protobuf.ListValue" => JsonSchemaEntry::array(JsonSchemaEntry::default()),
|
||||
"google.protobuf.NullValue" => JsonSchemaEntry::null(),
|
||||
name @ _ if name == root_name => JsonSchemaEntry::root_reference(),
|
||||
_ => JsonSchemaEntry::reference(fm.full_name()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_message_field(field: &FieldDescriptor) -> Option<MessageDescriptor> {
|
||||
match field.kind() {
|
||||
prost_reflect::Kind::Message(m) => Some(m),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, serde::Serialize)]
|
||||
#[serde(default, rename_all = "camelCase")]
|
||||
pub struct JsonSchemaEntry {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
title: Option<String>,
|
||||
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
type_: Option<JsonType>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
format: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
properties: Option<HashMap<String, JsonSchemaEntry>>,
|
||||
|
||||
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
|
||||
enum_: Option<Vec<String>>,
|
||||
|
||||
// for map type
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
additional_properties: Option<Box<JsonSchemaEntry>>,
|
||||
|
||||
// Set all properties to required
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
required: Option<Vec<String>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
items: Option<Box<JsonSchemaEntry>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "$defs")]
|
||||
defs: Option<HashMap<String, JsonSchemaEntry>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "$ref")]
|
||||
ref_: Option<String>,
|
||||
}
|
||||
|
||||
impl JsonSchemaEntry {
|
||||
pub fn add_property(&mut self, name: String, entry: JsonSchemaEntry) {
|
||||
if self.properties.is_none() {
|
||||
self.properties = Some(HashMap::new());
|
||||
}
|
||||
self.properties.as_mut().unwrap().insert(name, entry);
|
||||
}
|
||||
|
||||
pub fn add_required(&mut self, name: String) {
|
||||
if self.required.is_none() {
|
||||
self.required = Some(Vec::new());
|
||||
}
|
||||
self.required.as_mut().unwrap().push(name);
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonSchemaEntry {
|
||||
pub fn object() -> Self {
|
||||
JsonSchemaEntry { type_: Some(JsonType::Object), ..Default::default() }
|
||||
}
|
||||
pub fn boolean() -> Self {
|
||||
JsonSchemaEntry { type_: Some(JsonType::Boolean), ..Default::default() }
|
||||
}
|
||||
pub fn number<S: Into<String>>(format: S) -> Self {
|
||||
JsonSchemaEntry {
|
||||
type_: Some(JsonType::Number),
|
||||
format: Some(format.into()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn string() -> Self {
|
||||
JsonSchemaEntry { type_: Some(JsonType::String), ..Default::default() }
|
||||
}
|
||||
|
||||
pub fn string_with_format<S: Into<String>>(format: S) -> Self {
|
||||
JsonSchemaEntry {
|
||||
type_: Some(JsonType::String),
|
||||
format: Some(format.into()),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn reference<S: AsRef<str>>(ref_: S) -> Self {
|
||||
JsonSchemaEntry { ref_: Some(format!("#/$defs/{}", ref_.as_ref())), ..Default::default() }
|
||||
}
|
||||
pub fn root_reference() -> Self {
|
||||
JsonSchemaEntry { ref_: Some("#".to_string()), ..Default::default() }
|
||||
}
|
||||
pub fn array(item: JsonSchemaEntry) -> Self {
|
||||
JsonSchemaEntry {
|
||||
type_: Some(JsonType::Array),
|
||||
items: Some(Box::new(item)),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn enums(enums: Vec<String>) -> Self {
|
||||
JsonSchemaEntry { type_: Some(JsonType::String), enum_: Some(enums), ..Default::default() }
|
||||
}
|
||||
|
||||
pub fn map(value_type: JsonSchemaEntry) -> Self {
|
||||
JsonSchemaEntry {
|
||||
type_: Some(JsonType::Object),
|
||||
additional_properties: Some(Box::new(value_type)),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn null() -> Self {
|
||||
JsonSchemaEntry { type_: Some(JsonType::Null), ..Default::default() }
|
||||
}
|
||||
}
|
||||
|
||||
enum JsonType {
|
||||
String,
|
||||
Number,
|
||||
Object,
|
||||
Array,
|
||||
Boolean,
|
||||
Null,
|
||||
_UNKNOWN,
|
||||
}
|
||||
|
||||
impl Default for JsonType {
|
||||
fn default() -> Self {
|
||||
JsonType::_UNKNOWN
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for JsonType {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
JsonType::String => serializer.serialize_str("string"),
|
||||
JsonType::Number => serializer.serialize_str("number"),
|
||||
JsonType::Object => serializer.serialize_str("object"),
|
||||
JsonType::Array => serializer.serialize_str("array"),
|
||||
JsonType::Boolean => serializer.serialize_str("boolean"),
|
||||
JsonType::Null => serializer.serialize_str("null"),
|
||||
JsonType::_UNKNOWN => serializer.serialize_str("unknown"),
|
||||
}
|
||||
}
|
||||
}
|
||||
54
crates/yaak-grpc/src/lib.rs
Normal file
54
crates/yaak-grpc/src/lib.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use prost_reflect::{DynamicMessage, MethodDescriptor, SerializeOptions};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Deserializer;
|
||||
|
||||
mod any;
|
||||
mod client;
|
||||
mod codec;
|
||||
pub mod error;
|
||||
mod json_schema;
|
||||
pub mod manager;
|
||||
mod reflection;
|
||||
mod transport;
|
||||
|
||||
pub use tonic::Code;
|
||||
pub use tonic::metadata::*;
|
||||
|
||||
pub fn serialize_options() -> SerializeOptions {
|
||||
SerializeOptions::new().skip_default_fields(false)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
#[serde(default, rename_all = "camelCase")]
|
||||
pub struct ServiceDefinition {
|
||||
pub name: String,
|
||||
pub methods: Vec<MethodDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
#[serde(default, rename_all = "camelCase")]
|
||||
pub struct MethodDefinition {
|
||||
pub name: String,
|
||||
pub schema: String,
|
||||
pub client_streaming: bool,
|
||||
pub server_streaming: bool,
|
||||
}
|
||||
|
||||
static SERIALIZE_OPTIONS: &'static SerializeOptions =
|
||||
&SerializeOptions::new().skip_default_fields(false).stringify_64_bit_integers(false);
|
||||
|
||||
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
|
||||
let mut buf = Vec::new();
|
||||
let mut se = serde_json::Serializer::pretty(&mut buf);
|
||||
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS).map_err(|e| e.to_string())?;
|
||||
let s = String::from_utf8(buf).expect("serde_json to emit valid utf8");
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
pub fn deserialize_message(msg: &str, method: MethodDescriptor) -> Result<DynamicMessage, String> {
|
||||
let mut deserializer = Deserializer::from_str(&msg);
|
||||
let req_message = DynamicMessage::deserialize(method.input(), &mut deserializer)
|
||||
.map_err(|e| e.to_string())?;
|
||||
deserializer.end().map_err(|e| e.to_string())?;
|
||||
Ok(req_message)
|
||||
}
|
||||
426
crates/yaak-grpc/src/manager.rs
Normal file
426
crates/yaak-grpc/src/manager.rs
Normal file
@@ -0,0 +1,426 @@
|
||||
use crate::codec::DynamicCodec;
|
||||
use crate::error::Error::GenericError;
|
||||
use crate::error::Result;
|
||||
use crate::reflection::{
|
||||
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path, reflect_types_for_message,
|
||||
};
|
||||
use crate::transport::get_transport;
|
||||
use crate::{MethodDefinition, ServiceDefinition, json_schema};
|
||||
use hyper_rustls::HttpsConnector;
|
||||
use hyper_util::client::legacy::Client;
|
||||
use hyper_util::client::legacy::connect::HttpConnector;
|
||||
use log::{info, warn};
|
||||
pub use prost_reflect::DynamicMessage;
|
||||
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
|
||||
use serde_json::Deserializer;
|
||||
use std::collections::BTreeMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::body::BoxBody;
|
||||
use tonic::metadata::{MetadataKey, MetadataValue};
|
||||
use tonic::transport::Uri;
|
||||
use tonic::{IntoRequest, IntoStreamingRequest, Request, Response, Status, Streaming};
|
||||
use yaak_tls::ClientCertificateConfig;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GrpcConnection {
|
||||
pool: Arc<RwLock<DescriptorPool>>,
|
||||
conn: Client<HttpsConnector<HttpConnector>, BoxBody>,
|
||||
pub uri: Uri,
|
||||
use_reflection: bool,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct GrpcStreamError {
|
||||
pub message: String,
|
||||
pub status: Option<Status>,
|
||||
}
|
||||
|
||||
impl Error for GrpcStreamError {}
|
||||
|
||||
impl Display for GrpcStreamError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.status {
|
||||
Some(status) => write!(f, "[{}] {}", status, self.message),
|
||||
None => write!(f, "{}", self.message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for GrpcStreamError {
|
||||
fn from(value: String) -> Self {
|
||||
GrpcStreamError { message: value.to_string(), status: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Status> for GrpcStreamError {
|
||||
fn from(s: Status) -> Self {
|
||||
GrpcStreamError { message: s.message().to_string(), status: Some(s) }
|
||||
}
|
||||
}
|
||||
|
||||
impl GrpcConnection {
|
||||
pub async fn method(&self, service: &str, method: &str) -> Result<MethodDescriptor> {
|
||||
let service = self.service(service).await?;
|
||||
let method = service
|
||||
.methods()
|
||||
.find(|m| m.name() == method)
|
||||
.ok_or(GenericError("Failed to find method".to_string()))?;
|
||||
Ok(method)
|
||||
}
|
||||
|
||||
async fn service(&self, service: &str) -> Result<ServiceDescriptor> {
|
||||
let pool = self.pool.read().await;
|
||||
let service = pool
|
||||
.get_service_by_name(service)
|
||||
.ok_or(GenericError("Failed to find service".to_string()))?;
|
||||
Ok(service)
|
||||
}
|
||||
|
||||
pub async fn unary(
|
||||
&self,
|
||||
service: &str,
|
||||
method: &str,
|
||||
message: &str,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<Response<DynamicMessage>> {
|
||||
if self.use_reflection {
|
||||
reflect_types_for_message(self.pool.clone(), &self.uri, message, metadata, client_cert)
|
||||
.await?;
|
||||
}
|
||||
let method = &self.method(&service, &method).await?;
|
||||
let input_message = method.input();
|
||||
|
||||
let mut deserializer = Deserializer::from_str(message);
|
||||
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
|
||||
deserializer.end()?;
|
||||
|
||||
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
|
||||
|
||||
let mut req = req_message.into_request();
|
||||
decorate_req(metadata, &mut req)?;
|
||||
|
||||
let path = method_desc_to_path(method);
|
||||
let codec = DynamicCodec::new(method.clone());
|
||||
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
|
||||
|
||||
Ok(client.unary(req, path, codec).await?)
|
||||
}
|
||||
|
||||
pub async fn streaming(
|
||||
&self,
|
||||
service: &str,
|
||||
method: &str,
|
||||
stream: ReceiverStream<String>,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<Response<Streaming<DynamicMessage>>> {
|
||||
let method = &self.method(&service, &method).await?;
|
||||
let mapped_stream = {
|
||||
let input_message = method.input();
|
||||
let pool = self.pool.clone();
|
||||
let uri = self.uri.clone();
|
||||
let md = metadata.clone();
|
||||
let use_reflection = self.use_reflection.clone();
|
||||
let client_cert = client_cert.clone();
|
||||
stream.filter_map(move |json| {
|
||||
let pool = pool.clone();
|
||||
let uri = uri.clone();
|
||||
let input_message = input_message.clone();
|
||||
let md = md.clone();
|
||||
let use_reflection = use_reflection.clone();
|
||||
let client_cert = client_cert.clone();
|
||||
tokio::runtime::Handle::current().block_on(async move {
|
||||
if use_reflection {
|
||||
if let Err(e) =
|
||||
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
|
||||
{
|
||||
warn!("Failed to resolve Any types: {e}");
|
||||
}
|
||||
}
|
||||
let mut de = Deserializer::from_str(&json);
|
||||
match DynamicMessage::deserialize(input_message, &mut de) {
|
||||
Ok(m) => Some(m),
|
||||
Err(e) => {
|
||||
warn!("Failed to deserialize message: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
|
||||
let path = method_desc_to_path(method);
|
||||
let codec = DynamicCodec::new(method.clone());
|
||||
|
||||
let mut req = mapped_stream.into_streaming_request();
|
||||
decorate_req(metadata, &mut req)?;
|
||||
|
||||
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
|
||||
Ok(client.streaming(req, path, codec).await?)
|
||||
}
|
||||
|
||||
pub async fn client_streaming(
|
||||
&self,
|
||||
service: &str,
|
||||
method: &str,
|
||||
stream: ReceiverStream<String>,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<Response<DynamicMessage>> {
|
||||
let method = &self.method(&service, &method).await?;
|
||||
let mapped_stream = {
|
||||
let input_message = method.input();
|
||||
let pool = self.pool.clone();
|
||||
let uri = self.uri.clone();
|
||||
let md = metadata.clone();
|
||||
let use_reflection = self.use_reflection.clone();
|
||||
let client_cert = client_cert.clone();
|
||||
stream.filter_map(move |json| {
|
||||
let pool = pool.clone();
|
||||
let uri = uri.clone();
|
||||
let input_message = input_message.clone();
|
||||
let md = md.clone();
|
||||
let use_reflection = use_reflection.clone();
|
||||
let client_cert = client_cert.clone();
|
||||
tokio::runtime::Handle::current().block_on(async move {
|
||||
if use_reflection {
|
||||
if let Err(e) =
|
||||
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
|
||||
{
|
||||
warn!("Failed to resolve Any types: {e}");
|
||||
}
|
||||
}
|
||||
let mut de = Deserializer::from_str(&json);
|
||||
match DynamicMessage::deserialize(input_message, &mut de) {
|
||||
Ok(m) => Some(m),
|
||||
Err(e) => {
|
||||
warn!("Failed to deserialize message: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
};
|
||||
|
||||
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
|
||||
let path = method_desc_to_path(method);
|
||||
let codec = DynamicCodec::new(method.clone());
|
||||
|
||||
let mut req = mapped_stream.into_streaming_request();
|
||||
decorate_req(metadata, &mut req)?;
|
||||
|
||||
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
|
||||
Ok(client
|
||||
.client_streaming(req, path, codec)
|
||||
.await
|
||||
.map_err(|e| GrpcStreamError { message: e.message().to_string(), status: Some(e) })?)
|
||||
}
|
||||
|
||||
pub async fn server_streaming(
|
||||
&self,
|
||||
service: &str,
|
||||
method: &str,
|
||||
message: &str,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) -> Result<Response<Streaming<DynamicMessage>>> {
|
||||
let method = &self.method(&service, &method).await?;
|
||||
let input_message = method.input();
|
||||
|
||||
let mut deserializer = Deserializer::from_str(message);
|
||||
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
|
||||
deserializer.end()?;
|
||||
|
||||
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
|
||||
|
||||
let mut req = req_message.into_request();
|
||||
decorate_req(metadata, &mut req)?;
|
||||
|
||||
let path = method_desc_to_path(method);
|
||||
let codec = DynamicCodec::new(method.clone());
|
||||
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
|
||||
Ok(client.server_streaming(req, path, codec).await?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for GrpcHandle to compile proto files
|
||||
#[derive(Clone)]
|
||||
pub struct GrpcConfig {
|
||||
/// Path to the protoc include directory (vendored/protoc/include)
|
||||
pub protoc_include_dir: PathBuf,
|
||||
/// Path to the yaakprotoc sidecar binary
|
||||
pub protoc_bin_path: PathBuf,
|
||||
}
|
||||
|
||||
pub struct GrpcHandle {
|
||||
config: GrpcConfig,
|
||||
pools: BTreeMap<String, DescriptorPool>,
|
||||
}
|
||||
|
||||
impl GrpcHandle {
|
||||
pub fn new(config: GrpcConfig) -> Self {
|
||||
let pools = BTreeMap::new();
|
||||
Self { pools, config }
|
||||
}
|
||||
}
|
||||
|
||||
impl GrpcHandle {
|
||||
/// Remove cached descriptor pool for the given key, if present.
|
||||
pub fn invalidate_pool(&mut self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) {
|
||||
let key = make_pool_key(id, uri, proto_files);
|
||||
self.pools.remove(&key);
|
||||
}
|
||||
|
||||
pub async fn reflect(
|
||||
&mut self,
|
||||
id: &str,
|
||||
uri: &str,
|
||||
proto_files: &Vec<PathBuf>,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<bool> {
|
||||
let server_reflection = proto_files.is_empty();
|
||||
let key = make_pool_key(id, uri, proto_files);
|
||||
|
||||
// If we already have a pool for this key, reuse it and avoid re-reflection
|
||||
if self.pools.contains_key(&key) {
|
||||
return Ok(server_reflection);
|
||||
}
|
||||
|
||||
let pool = if server_reflection {
|
||||
let full_uri = uri_from_str(uri)?;
|
||||
fill_pool_from_reflection(&full_uri, metadata, validate_certificates, client_cert).await
|
||||
} else {
|
||||
fill_pool_from_files(&self.config, proto_files).await
|
||||
}?;
|
||||
|
||||
self.pools.insert(key, pool.clone());
|
||||
Ok(server_reflection)
|
||||
}
|
||||
|
||||
pub async fn services(
|
||||
&mut self,
|
||||
id: &str,
|
||||
uri: &str,
|
||||
proto_files: &Vec<PathBuf>,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
skip_cache: bool,
|
||||
) -> Result<Vec<ServiceDefinition>> {
|
||||
// Ensure we have a pool; reflect only if missing
|
||||
if skip_cache || self.get_pool(id, uri, proto_files).is_none() {
|
||||
info!("Reflecting gRPC services for {} at {}", id, uri);
|
||||
self.reflect(id, uri, proto_files, metadata, validate_certificates, client_cert)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let pool = self
|
||||
.get_pool(id, uri, proto_files)
|
||||
.ok_or(GenericError("Failed to get pool".to_string()))?;
|
||||
Ok(self.services_from_pool(&pool))
|
||||
}
|
||||
|
||||
fn services_from_pool(&self, pool: &DescriptorPool) -> Vec<ServiceDefinition> {
|
||||
pool.services()
|
||||
.map(|s| {
|
||||
let mut def =
|
||||
ServiceDefinition { name: s.full_name().to_string(), methods: vec![] };
|
||||
for method in s.methods() {
|
||||
let input_message = method.input();
|
||||
def.methods.push(MethodDefinition {
|
||||
name: method.name().to_string(),
|
||||
server_streaming: method.is_server_streaming(),
|
||||
client_streaming: method.is_client_streaming(),
|
||||
schema: serde_json::to_string_pretty(&json_schema::message_to_json_schema(
|
||||
&pool,
|
||||
input_message,
|
||||
))
|
||||
.expect("Failed to serialize JSON schema"),
|
||||
})
|
||||
}
|
||||
def
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
&mut self,
|
||||
id: &str,
|
||||
uri: &str,
|
||||
proto_files: &Vec<PathBuf>,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<GrpcConnection> {
|
||||
let use_reflection = proto_files.is_empty();
|
||||
if self.get_pool(id, uri, proto_files).is_none() {
|
||||
self.reflect(
|
||||
id,
|
||||
uri,
|
||||
proto_files,
|
||||
metadata,
|
||||
validate_certificates,
|
||||
client_cert.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
let pool = self
|
||||
.get_pool(id, uri, proto_files)
|
||||
.ok_or(GenericError("Failed to get pool".to_string()))?
|
||||
.clone();
|
||||
let uri = uri_from_str(uri)?;
|
||||
let conn = get_transport(validate_certificates, client_cert.clone())?;
|
||||
Ok(GrpcConnection { pool: Arc::new(RwLock::new(pool)), use_reflection, conn, uri })
|
||||
}
|
||||
|
||||
fn get_pool(&self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> Option<&DescriptorPool> {
|
||||
self.pools.get(make_pool_key(id, uri, proto_files).as_str())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn decorate_req<T>(
|
||||
metadata: &BTreeMap<String, String>,
|
||||
req: &mut Request<T>,
|
||||
) -> Result<()> {
|
||||
for (k, v) in metadata {
|
||||
req.metadata_mut()
|
||||
.insert(MetadataKey::from_str(k.as_str())?, MetadataValue::from_str(v.as_str())?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uri_from_str(uri_str: &str) -> Result<Uri> {
|
||||
match Uri::from_str(uri_str) {
|
||||
Ok(uri) => Ok(uri),
|
||||
Err(err) => {
|
||||
// Uri::from_str basically only returns "invalid format" so we add more context here
|
||||
Err(GenericError(format!("Failed to parse URL, {}", err.to_string())))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_pool_key(id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> String {
|
||||
let pool_key = format!(
|
||||
"{}::{}::{}",
|
||||
id,
|
||||
uri,
|
||||
proto_files
|
||||
.iter()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(":")
|
||||
);
|
||||
|
||||
format!("{:x}", md5::compute(pool_key))
|
||||
}
|
||||
452
crates/yaak-grpc/src/reflection.rs
Normal file
452
crates/yaak-grpc/src/reflection.rs
Normal file
@@ -0,0 +1,452 @@
|
||||
use crate::any::collect_any_types;
|
||||
use crate::client::AutoReflectionClient;
|
||||
use crate::error::Error::GenericError;
|
||||
use crate::error::Result;
|
||||
use crate::manager::GrpcConfig;
|
||||
use anyhow::anyhow;
|
||||
use async_recursion::async_recursion;
|
||||
use log::{debug, info, warn};
|
||||
use prost::Message;
|
||||
use prost_reflect::{DescriptorPool, MethodDescriptor};
|
||||
use prost_types::{FileDescriptorProto, FileDescriptorSet};
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::env::temp_dir;
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use tokio::fs;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::RwLock;
|
||||
use tonic::codegen::http::uri::PathAndQuery;
|
||||
use tonic::transport::Uri;
|
||||
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
|
||||
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
|
||||
use yaak_tls::ClientCertificateConfig;
|
||||
|
||||
pub async fn fill_pool_from_files(
|
||||
config: &GrpcConfig,
|
||||
paths: &Vec<PathBuf>,
|
||||
) -> Result<DescriptorPool> {
|
||||
let mut pool = DescriptorPool::new();
|
||||
let random_file_name = format!("{}.desc", uuid::Uuid::new_v4());
|
||||
let desc_path = temp_dir().join(random_file_name);
|
||||
|
||||
// HACK: Remove UNC prefix for Windows paths
|
||||
let global_import_dir =
|
||||
dunce::simplified(config.protoc_include_dir.as_path()).to_string_lossy().to_string();
|
||||
let desc_path = dunce::simplified(desc_path.as_path());
|
||||
|
||||
let mut args = vec![
|
||||
"--include_imports".to_string(),
|
||||
"--include_source_info".to_string(),
|
||||
"-I".to_string(),
|
||||
global_import_dir,
|
||||
"-o".to_string(),
|
||||
desc_path.to_string_lossy().to_string(),
|
||||
];
|
||||
|
||||
let mut include_dirs = HashSet::new();
|
||||
let mut include_protos = HashSet::new();
|
||||
|
||||
for p in paths {
|
||||
if !p.exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Dirs are added as includes
|
||||
if p.is_dir() {
|
||||
include_dirs.insert(p.to_string_lossy().to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
let parent = p.as_path().parent();
|
||||
if let Some(parent_path) = parent {
|
||||
match find_parent_proto_dir(parent_path) {
|
||||
None => {
|
||||
// Add parent/grandparent as fallback
|
||||
include_dirs.insert(parent_path.to_string_lossy().to_string());
|
||||
if let Some(grandparent_path) = parent_path.parent() {
|
||||
include_dirs.insert(grandparent_path.to_string_lossy().to_string());
|
||||
}
|
||||
}
|
||||
Some(p) => {
|
||||
include_dirs.insert(p.to_string_lossy().to_string());
|
||||
}
|
||||
};
|
||||
} else {
|
||||
debug!("ignoring {:?} since it does not exist.", parent)
|
||||
}
|
||||
|
||||
include_protos.insert(p.to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
for d in include_dirs.clone() {
|
||||
args.push("-I".to_string());
|
||||
args.push(d);
|
||||
}
|
||||
for p in include_protos.clone() {
|
||||
args.push(p);
|
||||
}
|
||||
|
||||
info!("Invoking protoc with {}", args.join(" "));
|
||||
|
||||
let out = Command::new(&config.protoc_bin_path)
|
||||
.args(&args)
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| GenericError(format!("Failed to run protoc: {}", e)))?;
|
||||
|
||||
if !out.status.success() {
|
||||
return Err(GenericError(format!(
|
||||
"protoc failed with status {}: {}",
|
||||
out.status.code().unwrap_or(-1),
|
||||
String::from_utf8_lossy(out.stderr.as_slice())
|
||||
)));
|
||||
}
|
||||
|
||||
let bytes = fs::read(desc_path).await?;
|
||||
let fdp = FileDescriptorSet::decode(bytes.deref())?;
|
||||
pool.add_file_descriptor_set(fdp)?;
|
||||
|
||||
fs::remove_file(desc_path).await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
pub async fn fill_pool_from_reflection(
|
||||
uri: &Uri,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<DescriptorPool> {
|
||||
let mut pool = DescriptorPool::new();
|
||||
let mut client = AutoReflectionClient::new(uri, validate_certificates, client_cert)?;
|
||||
|
||||
for service in list_services(&mut client, metadata).await? {
|
||||
if service == "grpc.reflection.v1alpha.ServerReflection" {
|
||||
continue;
|
||||
}
|
||||
if service == "grpc.reflection.v1.ServerReflection" {
|
||||
continue;
|
||||
}
|
||||
debug!("Fetching descriptors for {}", service);
|
||||
file_descriptor_set_from_service_name(&service, &mut pool, &mut client, metadata).await;
|
||||
}
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn list_services(
|
||||
client: &mut AutoReflectionClient,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) -> Result<Vec<String>> {
|
||||
let response =
|
||||
client.send_reflection_request(MessageRequest::ListServices("".into()), metadata).await?;
|
||||
|
||||
let list_services_response = match response {
|
||||
MessageResponse::ListServicesResponse(resp) => resp,
|
||||
_ => panic!("Expected a ListServicesResponse variant"),
|
||||
};
|
||||
|
||||
Ok(list_services_response.service.iter().map(|s| s.name.clone()).collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
async fn file_descriptor_set_from_service_name(
|
||||
service_name: &str,
|
||||
pool: &mut DescriptorPool,
|
||||
client: &mut AutoReflectionClient,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) {
|
||||
let response = match client
|
||||
.send_reflection_request(
|
||||
MessageRequest::FileContainingSymbol(service_name.into()),
|
||||
metadata,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
warn!("Error fetching file descriptor for service {}: {:?}", service_name, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let file_descriptor_response = match response {
|
||||
MessageResponse::FileDescriptorResponse(resp) => resp,
|
||||
_ => panic!("Expected a FileDescriptorResponse variant"),
|
||||
};
|
||||
|
||||
add_file_descriptors_to_pool(
|
||||
file_descriptor_response.file_descriptor_proto,
|
||||
pool,
|
||||
client,
|
||||
metadata,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn reflect_types_for_message(
|
||||
pool: Arc<RwLock<DescriptorPool>>,
|
||||
uri: &Uri,
|
||||
json: &str,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<()> {
|
||||
// 1. Collect all Any types in the JSON
|
||||
let mut extra_types = Vec::new();
|
||||
collect_any_types(json, &mut extra_types);
|
||||
|
||||
if extra_types.is_empty() {
|
||||
return Ok(()); // nothing to do
|
||||
}
|
||||
|
||||
let mut client = AutoReflectionClient::new(uri, false, client_cert)?;
|
||||
for extra_type in extra_types {
|
||||
{
|
||||
let guard = pool.read().await;
|
||||
if guard.get_message_by_name(&extra_type).is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
info!("Adding file descriptor for {:?} from reflection", extra_type);
|
||||
let req = MessageRequest::FileContainingSymbol(extra_type.clone().into());
|
||||
let resp = match client.send_reflection_request(req, metadata).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
return Err(GenericError(format!(
|
||||
"Error sending reflection request for @type \"{extra_type}\": {e:?}",
|
||||
)));
|
||||
}
|
||||
};
|
||||
let files = match resp {
|
||||
MessageResponse::FileDescriptorResponse(resp) => resp.file_descriptor_proto,
|
||||
_ => panic!("Expected a FileDescriptorResponse variant"),
|
||||
};
|
||||
|
||||
{
|
||||
let mut guard = pool.write().await;
|
||||
add_file_descriptors_to_pool(files, &mut *guard, &mut client, metadata).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
pub(crate) async fn add_file_descriptors_to_pool(
|
||||
fds: Vec<Vec<u8>>,
|
||||
pool: &mut DescriptorPool,
|
||||
client: &mut AutoReflectionClient,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) {
|
||||
let mut topo_sort = topology::SimpleTopoSort::new();
|
||||
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
|
||||
|
||||
for fd in fds {
|
||||
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
|
||||
|
||||
topo_sort.insert(fdp.name().to_string(), fdp.dependency.clone());
|
||||
fd_mapping.insert(fdp.name().to_string(), fdp);
|
||||
}
|
||||
|
||||
for node in topo_sort {
|
||||
match node {
|
||||
Ok(node) => {
|
||||
if let Some(fdp) = fd_mapping.remove(&node) {
|
||||
pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto");
|
||||
} else {
|
||||
file_descriptor_set_by_filename(node.as_str(), pool, client, metadata).await;
|
||||
}
|
||||
}
|
||||
Err(_) => panic!("proto file got cycle!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn file_descriptor_set_by_filename(
|
||||
filename: &str,
|
||||
pool: &mut DescriptorPool,
|
||||
client: &mut AutoReflectionClient,
|
||||
metadata: &BTreeMap<String, String>,
|
||||
) {
|
||||
// We already fetched this file
|
||||
if let Some(_) = pool.get_file_by_name(filename) {
|
||||
return;
|
||||
}
|
||||
|
||||
let msg = MessageRequest::FileByFilename(filename.into());
|
||||
let response = client.send_reflection_request(msg, metadata).await;
|
||||
let file_descriptor_response = match response {
|
||||
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
|
||||
Ok(_) => {
|
||||
panic!("Expected a FileDescriptorResponse variant")
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error fetching file descriptor for {}: {:?}", filename, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
add_file_descriptors_to_pool(
|
||||
file_descriptor_response.file_descriptor_proto,
|
||||
pool,
|
||||
client,
|
||||
metadata,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
|
||||
let full_name = md.full_name();
|
||||
let (namespace, method_name) = full_name
|
||||
.rsplit_once('.')
|
||||
.ok_or_else(|| anyhow!("invalid method path"))
|
||||
.expect("invalid method path");
|
||||
PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path")
|
||||
}
|
||||
|
||||
mod topology {
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
pub struct SimpleTopoSort<T> {
|
||||
out_graph: HashMap<T, HashSet<T>>,
|
||||
in_graph: HashMap<T, HashSet<T>>,
|
||||
}
|
||||
|
||||
impl<T> SimpleTopoSort<T>
|
||||
where
|
||||
T: Eq + std::hash::Hash + Clone,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
SimpleTopoSort { out_graph: HashMap::new(), in_graph: HashMap::new() }
|
||||
}
|
||||
|
||||
pub fn insert<I: IntoIterator<Item = T>>(&mut self, node: T, deps: I) {
|
||||
self.out_graph.entry(node.clone()).or_insert(HashSet::new());
|
||||
for dep in deps {
|
||||
self.out_graph.entry(node.clone()).or_insert(HashSet::new()).insert(dep.clone());
|
||||
self.in_graph.entry(dep.clone()).or_insert(HashSet::new()).insert(node.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> IntoIterator for SimpleTopoSort<T>
|
||||
where
|
||||
T: Eq + std::hash::Hash + Clone,
|
||||
{
|
||||
type Item = <SimpleTopoSortIter<T> as Iterator>::Item;
|
||||
type IntoIter = SimpleTopoSortIter<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SimpleTopoSortIter::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SimpleTopoSortIter<T> {
|
||||
data: SimpleTopoSort<T>,
|
||||
zero_indegree: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> SimpleTopoSortIter<T>
|
||||
where
|
||||
T: Eq + std::hash::Hash + Clone,
|
||||
{
|
||||
pub fn new(data: SimpleTopoSort<T>) -> Self {
|
||||
let mut zero_indegree = Vec::new();
|
||||
for (node, _) in data.in_graph.iter() {
|
||||
if !data.out_graph.contains_key(node) {
|
||||
zero_indegree.push(node.clone());
|
||||
}
|
||||
}
|
||||
for (node, deps) in data.out_graph.iter() {
|
||||
if deps.is_empty() {
|
||||
zero_indegree.push(node.clone());
|
||||
}
|
||||
}
|
||||
|
||||
SimpleTopoSortIter { data, zero_indegree }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Iterator for SimpleTopoSortIter<T>
|
||||
where
|
||||
T: Eq + std::hash::Hash + Clone,
|
||||
{
|
||||
type Item = Result<T, &'static str>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.zero_indegree.is_empty() {
|
||||
if self.data.out_graph.is_empty() {
|
||||
return None;
|
||||
}
|
||||
return Some(Err("Cycle detected"));
|
||||
}
|
||||
|
||||
let node = self.zero_indegree.pop().unwrap();
|
||||
if let Some(parents) = self.data.in_graph.get(&node) {
|
||||
for parent in parents.iter() {
|
||||
let deps = self.data.out_graph.get_mut(parent).unwrap();
|
||||
deps.remove(&node);
|
||||
if deps.is_empty() {
|
||||
self.zero_indegree.push(parent.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
self.data.out_graph.remove(&node);
|
||||
|
||||
Some(Ok(node))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort() {
|
||||
{
|
||||
let mut topo_sort = SimpleTopoSort::new();
|
||||
topo_sort.insert("a", []);
|
||||
|
||||
for node in topo_sort {
|
||||
match node {
|
||||
Ok(n) => assert_eq!(n, "a"),
|
||||
Err(e) => panic!("err {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut topo_sort = SimpleTopoSort::new();
|
||||
topo_sort.insert("a", ["b"]);
|
||||
topo_sort.insert("b", []);
|
||||
|
||||
let mut iter = topo_sort.into_iter();
|
||||
match iter.next() {
|
||||
Some(Ok(n)) => assert_eq!(n, "b"),
|
||||
_ => panic!("err"),
|
||||
}
|
||||
match iter.next() {
|
||||
Some(Ok(n)) => assert_eq!(n, "a"),
|
||||
_ => panic!("err"),
|
||||
}
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_parent_proto_dir(start_path: impl AsRef<Path>) -> Option<PathBuf> {
|
||||
let mut dir = start_path.as_ref().canonicalize().ok()?;
|
||||
|
||||
loop {
|
||||
if let Some(name) = dir.file_name().and_then(|n| n.to_str()) {
|
||||
if name == "proto" {
|
||||
return Some(dir);
|
||||
}
|
||||
}
|
||||
|
||||
let parent = dir.parent()?;
|
||||
if parent == dir {
|
||||
return None; // Reached root
|
||||
}
|
||||
|
||||
dir = parent.to_path_buf();
|
||||
}
|
||||
}
|
||||
40
crates/yaak-grpc/src/transport.rs
Normal file
40
crates/yaak-grpc/src/transport.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use crate::error::Result;
|
||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||
use hyper_util::client::legacy::Client;
|
||||
use hyper_util::client::legacy::connect::HttpConnector;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use log::info;
|
||||
use tonic::body::BoxBody;
|
||||
use yaak_tls::{ClientCertificateConfig, get_tls_config};
|
||||
|
||||
// I think ALPN breaks this because we're specifying http2_only
|
||||
const WITH_ALPN: bool = false;
|
||||
|
||||
pub(crate) fn get_transport(
|
||||
validate_certificates: bool,
|
||||
client_cert: Option<ClientCertificateConfig>,
|
||||
) -> Result<Client<HttpsConnector<HttpConnector>, BoxBody>> {
|
||||
let tls_config = get_tls_config(validate_certificates, WITH_ALPN, client_cert.clone())?;
|
||||
|
||||
let mut http = HttpConnector::new();
|
||||
http.enforce_http(false);
|
||||
|
||||
let connector = HttpsConnectorBuilder::new()
|
||||
.with_tls_config(tls_config)
|
||||
.https_or_http()
|
||||
.enable_http2()
|
||||
.build();
|
||||
|
||||
let client = Client::builder(TokioExecutor::new())
|
||||
.pool_max_idle_per_host(0)
|
||||
.http2_only(true)
|
||||
.build(connector);
|
||||
|
||||
info!(
|
||||
"Created gRPC client validate_certs={} client_cert={}",
|
||||
validate_certificates,
|
||||
client_cert.is_some()
|
||||
);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
Reference in New Issue
Block a user