Decouple core Yaak logic from Tauri (#354)

This commit is contained in:
Gregory Schier
2026-01-08 20:44:25 -08:00
committed by GitHub
parent 3bcc0b8356
commit ef80216ca1
465 changed files with 3052 additions and 6234 deletions

View File

@@ -0,0 +1,60 @@
use log::error;
pub(crate) fn collect_any_types(json: &str, out: &mut Vec<String>) {
let value = match serde_json::from_str(json).map_err(|e| e.to_string()) {
Ok(v) => v,
Err(e) => {
error!("Failed to parse gRPC message JSON: {e:?}");
return;
}
};
collect_any_types_value(&value, out);
}
fn collect_any_types_value(json: &serde_json::Value, out: &mut Vec<String>) {
match json {
serde_json::Value::Object(map) => {
if let Some(t) = map.get("@type").and_then(|v| v.as_str()) {
if let Some(full_name) = t.rsplit_once('/').map(|(_, n)| n) {
out.push(full_name.to_string());
}
}
for v in map.values() {
collect_any_types_value(v, out);
}
}
serde_json::Value::Array(arr) => {
for v in arr {
collect_any_types_value(v, out);
}
}
_ => {}
}
}
// Write tests for this
#[cfg(test)]
mod tests {
#[test]
fn test_collect_any_types() {
let json = r#"{
"mounts": [
{
"mountSource": {
"@type": "type.googleapis.com/mount_source.MountSourceRBDVolume",
"volumeID": "volumes/rbd"
}
}
],
"foo": {
"@type": "type.googleapis.com/foo.bar",
"foo": "fooo"
}
}"#;
let mut out = Vec::new();
super::collect_any_types(json, &mut out);
assert_eq!(out, vec!["foo.bar", "mount_source.MountSourceRBDVolume"]);
}
}

View File

@@ -0,0 +1,182 @@
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::decorate_req;
use crate::transport::get_transport;
use async_recursion::async_recursion;
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::debug;
use std::collections::BTreeMap;
use tokio_stream::StreamExt;
use tonic::Request;
use tonic::body::BoxBody;
use tonic::transport::Uri;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use tonic_reflection::pb::v1::{
ErrorResponse, ExtensionNumberResponse, ListServiceResponse, ServerReflectionRequest,
ServiceResponse,
};
use tonic_reflection::pb::v1::{ExtensionRequest, FileDescriptorResponse};
use tonic_reflection::pb::{v1, v1alpha};
use yaak_tls::ClientCertificateConfig;
pub struct AutoReflectionClient<T = Client<HttpsConnector<HttpConnector>, BoxBody>> {
use_v1alpha: bool,
client_v1: v1::server_reflection_client::ServerReflectionClient<T>,
client_v1alpha: v1alpha::server_reflection_client::ServerReflectionClient<T>,
}
impl AutoReflectionClient {
pub fn new(
uri: &Uri,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Self> {
let client_v1 = v1::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(validate_certificates, client_cert.clone())?,
uri.clone(),
);
let client_v1alpha = v1alpha::server_reflection_client::ServerReflectionClient::with_origin(
get_transport(validate_certificates, client_cert.clone())?,
uri.clone(),
);
Ok(AutoReflectionClient { use_v1alpha: false, client_v1, client_v1alpha })
}
#[async_recursion]
pub async fn send_reflection_request(
&mut self,
message: MessageRequest,
metadata: &BTreeMap<String, String>,
) -> Result<MessageResponse> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message.clone()),
};
if self.use_v1alpha {
let mut request =
Request::new(tokio_stream::once(to_v1alpha_request(reflection_request)));
decorate_req(metadata, &mut request)?;
self.client_v1alpha
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => {
GenericError("Failed to connect to endpoint".to_string())
}
tonic::Code::Unauthenticated => {
GenericError("Authentication failed".to_string())
}
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
_ => GenericError(e.to_string()),
})?
.into_inner()
.next()
.await
.ok_or(GenericError("Missing reflection message".to_string()))??
.message_response
.ok_or(GenericError("No reflection response".to_string()))
.map(|resp| to_v1_msg_response(resp))
} else {
let mut request = Request::new(tokio_stream::once(reflection_request));
decorate_req(metadata, &mut request)?;
let resp = self.client_v1.server_reflection_info(request).await;
match resp {
Ok(r) => Ok(r),
Err(e) => match e.code().clone() {
tonic::Code::Unimplemented => {
// If v1 fails, change to v1alpha and try again
debug!("gRPC schema reflection falling back to v1alpha");
self.use_v1alpha = true;
return self.send_reflection_request(message, metadata).await;
}
_ => Err(e),
},
}
.map_err(|e| match e.code() {
tonic::Code::Unavailable => {
GenericError("Failed to connect to endpoint".to_string())
}
tonic::Code::Unauthenticated => GenericError("Authentication failed".to_string()),
tonic::Code::DeadlineExceeded => GenericError("Deadline exceeded".to_string()),
_ => GenericError(e.to_string()),
})?
.into_inner()
.next()
.await
.ok_or(GenericError("Missing reflection message".to_string()))??
.message_response
.ok_or(GenericError("No reflection response".to_string()))
}
}
}
fn to_v1_msg_response(
response: v1alpha::server_reflection_response::MessageResponse,
) -> MessageResponse {
match response {
v1alpha::server_reflection_response::MessageResponse::FileDescriptorResponse(v) => {
MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
file_descriptor_proto: v.file_descriptor_proto,
})
}
v1alpha::server_reflection_response::MessageResponse::AllExtensionNumbersResponse(v) => {
MessageResponse::AllExtensionNumbersResponse(ExtensionNumberResponse {
extension_number: v.extension_number,
base_type_name: v.base_type_name,
})
}
v1alpha::server_reflection_response::MessageResponse::ListServicesResponse(v) => {
MessageResponse::ListServicesResponse(ListServiceResponse {
service: v
.service
.iter()
.map(|s| ServiceResponse { name: s.name.clone() })
.collect(),
})
}
v1alpha::server_reflection_response::MessageResponse::ErrorResponse(v) => {
MessageResponse::ErrorResponse(ErrorResponse {
error_code: v.error_code,
error_message: v.error_message,
})
}
}
}
fn to_v1alpha_request(request: ServerReflectionRequest) -> v1alpha::ServerReflectionRequest {
v1alpha::ServerReflectionRequest {
host: request.host,
message_request: request.message_request.map(|m| to_v1alpha_msg_request(m)),
}
}
fn to_v1alpha_msg_request(
message: MessageRequest,
) -> v1alpha::server_reflection_request::MessageRequest {
match message {
MessageRequest::FileByFilename(v) => {
v1alpha::server_reflection_request::MessageRequest::FileByFilename(v)
}
MessageRequest::FileContainingSymbol(v) => {
v1alpha::server_reflection_request::MessageRequest::FileContainingSymbol(v)
}
MessageRequest::FileContainingExtension(ExtensionRequest {
extension_number,
containing_type,
}) => v1alpha::server_reflection_request::MessageRequest::FileContainingExtension(
v1alpha::ExtensionRequest { extension_number, containing_type },
),
MessageRequest::AllExtensionNumbersOfType(v) => {
v1alpha::server_reflection_request::MessageRequest::AllExtensionNumbersOfType(v)
}
MessageRequest::ListServices(v) => {
v1alpha::server_reflection_request::MessageRequest::ListServices(v)
}
}
}

View File

@@ -0,0 +1,50 @@
use prost_reflect::prost::Message;
use prost_reflect::{DynamicMessage, MethodDescriptor};
use tonic::Status;
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
#[derive(Clone)]
pub struct DynamicCodec(MethodDescriptor);
impl DynamicCodec {
#[allow(dead_code)]
pub fn new(md: MethodDescriptor) -> Self {
Self(md)
}
}
impl Codec for DynamicCodec {
type Encode = DynamicMessage;
type Decode = DynamicMessage;
type Encoder = Self;
type Decoder = Self;
fn encoder(&mut self) -> Self::Encoder {
self.clone()
}
fn decoder(&mut self) -> Self::Decoder {
self.clone()
}
}
impl Encoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
item.encode(dst).expect("buffer is too small to decode this message");
Ok(())
}
}
impl Decoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let mut msg = DynamicMessage::new(self.0.output());
msg.merge(src).map_err(|err| Status::internal(err.to_string()))?;
Ok(Some(msg))
}
}

View File

@@ -0,0 +1,51 @@
use crate::manager::GrpcStreamError;
use prost::DecodeError;
use serde::{Serialize, Serializer};
use serde_json::Error as SerdeJsonError;
use std::io;
use thiserror::Error;
use tonic::Status;
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
TlsError(#[from] yaak_tls::error::Error),
#[error(transparent)]
TonicError(#[from] Status),
#[error("Prost reflect error: {0:?}")]
ProstReflectError(#[from] prost_reflect::DescriptorError),
#[error(transparent)]
DeserializerError(#[from] SerdeJsonError),
#[error(transparent)]
GrpcStreamError(#[from] GrpcStreamError),
#[error(transparent)]
GrpcDecodeError(#[from] DecodeError),
#[error(transparent)]
GrpcInvalidMetadataKeyError(#[from] tonic::metadata::errors::InvalidMetadataKey),
#[error(transparent)]
GrpcInvalidMetadataValueError(#[from] tonic::metadata::errors::InvalidMetadataValue),
#[error(transparent)]
IOError(#[from] io::Error),
#[error("GRPC error: {0}")]
GenericError(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,382 @@
use prost_reflect::{DescriptorPool, FieldDescriptor, MessageDescriptor};
use std::collections::{HashMap, HashSet, VecDeque};
pub fn message_to_json_schema(_: &DescriptorPool, root_msg: MessageDescriptor) -> JsonSchemaEntry {
JsonSchemaGenerator::generate_json_schema(root_msg)
}
struct JsonSchemaGenerator {
msg_mapping: HashMap<String, JsonSchemaEntry>,
}
impl JsonSchemaGenerator {
pub fn new() -> Self {
JsonSchemaGenerator { msg_mapping: HashMap::new() }
}
pub fn generate_json_schema(msg: MessageDescriptor) -> JsonSchemaEntry {
let generator = JsonSchemaGenerator::new();
generator.scan_root(msg)
}
fn add_message(&mut self, msg: &MessageDescriptor) {
let name = msg.full_name().to_string();
if self.msg_mapping.contains_key(&name) {
return;
}
self.msg_mapping.insert(name.clone(), JsonSchemaEntry::object());
}
pub fn scan_root(mut self, root_msg: MessageDescriptor) -> JsonSchemaEntry {
self.init_structure(root_msg.clone());
self.fill_properties(root_msg.clone());
let mut root = self.msg_mapping.remove(root_msg.full_name()).unwrap();
if self.msg_mapping.len() > 0 {
root.defs = Some(self.msg_mapping);
}
root
}
fn fill_properties(&mut self, root_msg: MessageDescriptor) {
let root_name = root_msg.full_name().to_string();
let mut visited = HashSet::new();
let mut msg_queue = VecDeque::new();
msg_queue.push_back(root_msg);
while !msg_queue.is_empty() {
let msg = msg_queue.pop_front().unwrap();
let msg_name = msg.full_name();
if visited.contains(msg_name) {
continue;
}
visited.insert(msg_name.to_string());
let entry = self.msg_mapping.get_mut(msg_name).unwrap();
for field in msg.fields() {
let field_name = field.name().to_string();
if matches!(field.cardinality(), prost_reflect::Cardinality::Required) {
entry.add_required(field_name.clone());
}
if let Some(oneof) = field.containing_oneof() {
for oneof_field in oneof.fields() {
if let Some(fm) = is_message_field(&oneof_field) {
msg_queue.push_back(fm);
}
entry.add_property(
oneof_field.name().to_string(),
field_to_type_or_ref(&root_name, oneof_field),
);
}
continue;
}
let (field_type, nest_msg) = {
if let Some(fm) = is_message_field(&field) {
if field.is_list() {
// repeated message type
(
JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)),
Some(fm),
)
} else if field.is_map() {
let value_field = fm.get_field_by_name("value").unwrap();
if let Some(fm) = is_message_field(&value_field) {
(
JsonSchemaEntry::map(field_to_type_or_ref(
&root_name,
value_field,
)),
Some(fm),
)
} else {
(
JsonSchemaEntry::map(field_to_type_or_ref(
&root_name,
value_field,
)),
None,
)
}
} else {
(field_to_type_or_ref(&root_name, field), Some(fm))
}
} else {
if field.is_list() {
// repeated scalar type
(JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), None)
} else {
(field_to_type_or_ref(&root_name, field), None)
}
}
};
if let Some(fm) = nest_msg {
msg_queue.push_back(fm);
}
entry.add_property(field_name, field_type);
}
}
}
fn init_structure(&mut self, root_msg: MessageDescriptor) {
let mut visited = HashSet::new();
let mut msg_queue = VecDeque::new();
msg_queue.push_back(root_msg.clone());
// level traversal, to make sure all message type is defined before used
while !msg_queue.is_empty() {
let msg = msg_queue.pop_front().unwrap();
let name = msg.full_name();
if visited.contains(name) {
continue;
}
visited.insert(name.to_string());
self.add_message(&msg);
for child in msg.child_messages() {
if child.is_map_entry() {
// for field with map<key, value> type, there will be a child message type *Entry generated
// just skip it
continue;
}
self.add_message(&child);
msg_queue.push_back(child);
}
for field in msg.fields() {
if let Some(oneof) = field.containing_oneof() {
for oneof_field in oneof.fields() {
if let Some(fm) = is_message_field(&oneof_field) {
self.add_message(&fm);
msg_queue.push_back(fm);
}
}
continue;
}
if field.is_map() {
// key is always scalar type, so no need to process
// value can be any type, so need to unpack value type
let map_field_msg = is_message_field(&field).unwrap();
let map_value_field = map_field_msg.get_field_by_name("value").unwrap();
if let Some(value_fm) = is_message_field(&map_value_field) {
self.add_message(&value_fm);
msg_queue.push_back(value_fm);
}
continue;
}
if let Some(fm) = is_message_field(&field) {
self.add_message(&fm);
msg_queue.push_back(fm);
}
}
}
}
}
fn field_to_type_or_ref(root_name: &str, field: FieldDescriptor) -> JsonSchemaEntry {
match field.kind() {
prost_reflect::Kind::Bool => JsonSchemaEntry::boolean(),
prost_reflect::Kind::Double => JsonSchemaEntry::number("double"),
prost_reflect::Kind::Float => JsonSchemaEntry::number("float"),
prost_reflect::Kind::Int32 => JsonSchemaEntry::number("int32"),
prost_reflect::Kind::Int64 => JsonSchemaEntry::string_with_format("int64"),
prost_reflect::Kind::Uint32 => JsonSchemaEntry::number("int64"),
prost_reflect::Kind::Uint64 => JsonSchemaEntry::string_with_format("uint64"),
prost_reflect::Kind::Sint32 => JsonSchemaEntry::number("sint32"),
prost_reflect::Kind::Sint64 => JsonSchemaEntry::string_with_format("sint64"),
prost_reflect::Kind::Fixed32 => JsonSchemaEntry::number("int64"),
prost_reflect::Kind::Fixed64 => JsonSchemaEntry::string_with_format("fixed64"),
prost_reflect::Kind::Sfixed32 => JsonSchemaEntry::number("sfixed32"),
prost_reflect::Kind::Sfixed64 => JsonSchemaEntry::string_with_format("sfixed64"),
prost_reflect::Kind::String => JsonSchemaEntry::string(),
prost_reflect::Kind::Bytes => JsonSchemaEntry::string_with_format("byte"),
prost_reflect::Kind::Enum(enums) => {
let values = enums.values().map(|v| v.name().to_string()).collect::<Vec<_>>();
JsonSchemaEntry::enums(values)
}
prost_reflect::Kind::Message(fm) => {
let field_type_full_name = fm.full_name();
match field_type_full_name {
// [Protocol Buffers Well-Known Types]: https://protobuf.dev/reference/protobuf/google.protobuf/
"google.protobuf.FieldMask" => JsonSchemaEntry::string(),
"google.protobuf.Timestamp" => JsonSchemaEntry::string_with_format("date-time"),
"google.protobuf.Duration" => JsonSchemaEntry::string(),
"google.protobuf.StringValue" => JsonSchemaEntry::string(),
"google.protobuf.BytesValue" => JsonSchemaEntry::string_with_format("byte"),
"google.protobuf.Int32Value" => JsonSchemaEntry::number("int32"),
"google.protobuf.UInt32Value" => JsonSchemaEntry::string_with_format("int64"),
"google.protobuf.Int64Value" => JsonSchemaEntry::string_with_format("int64"),
"google.protobuf.UInt64Value" => JsonSchemaEntry::string_with_format("uint64"),
"google.protobuf.FloatValue" => JsonSchemaEntry::number("float"),
"google.protobuf.DoubleValue" => JsonSchemaEntry::number("double"),
"google.protobuf.BoolValue" => JsonSchemaEntry::boolean(),
"google.protobuf.Empty" => JsonSchemaEntry::default(),
"google.protobuf.Struct" => JsonSchemaEntry::object(),
"google.protobuf.ListValue" => JsonSchemaEntry::array(JsonSchemaEntry::default()),
"google.protobuf.NullValue" => JsonSchemaEntry::null(),
name @ _ if name == root_name => JsonSchemaEntry::root_reference(),
_ => JsonSchemaEntry::reference(fm.full_name()),
}
}
}
}
fn is_message_field(field: &FieldDescriptor) -> Option<MessageDescriptor> {
match field.kind() {
prost_reflect::Kind::Message(m) => Some(m),
_ => None,
}
}
#[derive(Default, serde::Serialize)]
#[serde(default, rename_all = "camelCase")]
pub struct JsonSchemaEntry {
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
type_: Option<JsonType>,
#[serde(skip_serializing_if = "Option::is_none")]
format: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
properties: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
enum_: Option<Vec<String>>,
// for map type
#[serde(skip_serializing_if = "Option::is_none")]
additional_properties: Option<Box<JsonSchemaEntry>>,
// Set all properties to required
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
items: Option<Box<JsonSchemaEntry>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "$defs")]
defs: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(skip_serializing_if = "Option::is_none", rename = "$ref")]
ref_: Option<String>,
}
impl JsonSchemaEntry {
pub fn add_property(&mut self, name: String, entry: JsonSchemaEntry) {
if self.properties.is_none() {
self.properties = Some(HashMap::new());
}
self.properties.as_mut().unwrap().insert(name, entry);
}
pub fn add_required(&mut self, name: String) {
if self.required.is_none() {
self.required = Some(Vec::new());
}
self.required.as_mut().unwrap().push(name);
}
}
impl JsonSchemaEntry {
pub fn object() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Object), ..Default::default() }
}
pub fn boolean() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Boolean), ..Default::default() }
}
pub fn number<S: Into<String>>(format: S) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Number),
format: Some(format.into()),
..Default::default()
}
}
pub fn string() -> Self {
JsonSchemaEntry { type_: Some(JsonType::String), ..Default::default() }
}
pub fn string_with_format<S: Into<String>>(format: S) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::String),
format: Some(format.into()),
..Default::default()
}
}
pub fn reference<S: AsRef<str>>(ref_: S) -> Self {
JsonSchemaEntry { ref_: Some(format!("#/$defs/{}", ref_.as_ref())), ..Default::default() }
}
pub fn root_reference() -> Self {
JsonSchemaEntry { ref_: Some("#".to_string()), ..Default::default() }
}
pub fn array(item: JsonSchemaEntry) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Array),
items: Some(Box::new(item)),
..Default::default()
}
}
pub fn enums(enums: Vec<String>) -> Self {
JsonSchemaEntry { type_: Some(JsonType::String), enum_: Some(enums), ..Default::default() }
}
pub fn map(value_type: JsonSchemaEntry) -> Self {
JsonSchemaEntry {
type_: Some(JsonType::Object),
additional_properties: Some(Box::new(value_type)),
..Default::default()
}
}
pub fn null() -> Self {
JsonSchemaEntry { type_: Some(JsonType::Null), ..Default::default() }
}
}
enum JsonType {
String,
Number,
Object,
Array,
Boolean,
Null,
_UNKNOWN,
}
impl Default for JsonType {
fn default() -> Self {
JsonType::_UNKNOWN
}
}
impl serde::Serialize for JsonType {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
JsonType::String => serializer.serialize_str("string"),
JsonType::Number => serializer.serialize_str("number"),
JsonType::Object => serializer.serialize_str("object"),
JsonType::Array => serializer.serialize_str("array"),
JsonType::Boolean => serializer.serialize_str("boolean"),
JsonType::Null => serializer.serialize_str("null"),
JsonType::_UNKNOWN => serializer.serialize_str("unknown"),
}
}
}

View File

@@ -0,0 +1,54 @@
use prost_reflect::{DynamicMessage, MethodDescriptor, SerializeOptions};
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
mod any;
mod client;
mod codec;
pub mod error;
mod json_schema;
pub mod manager;
mod reflection;
mod transport;
pub use tonic::Code;
pub use tonic::metadata::*;
pub fn serialize_options() -> SerializeOptions {
SerializeOptions::new().skip_default_fields(false)
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct ServiceDefinition {
pub name: String,
pub methods: Vec<MethodDefinition>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct MethodDefinition {
pub name: String,
pub schema: String,
pub client_streaming: bool,
pub server_streaming: bool,
}
static SERIALIZE_OPTIONS: &'static SerializeOptions =
&SerializeOptions::new().skip_default_fields(false).stringify_64_bit_integers(false);
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
let mut buf = Vec::new();
let mut se = serde_json::Serializer::pretty(&mut buf);
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS).map_err(|e| e.to_string())?;
let s = String::from_utf8(buf).expect("serde_json to emit valid utf8");
Ok(s)
}
pub fn deserialize_message(msg: &str, method: MethodDescriptor) -> Result<DynamicMessage, String> {
let mut deserializer = Deserializer::from_str(&msg);
let req_message = DynamicMessage::deserialize(method.input(), &mut deserializer)
.map_err(|e| e.to_string())?;
deserializer.end().map_err(|e| e.to_string())?;
Ok(req_message)
}

View File

@@ -0,0 +1,426 @@
use crate::codec::DynamicCodec;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::reflection::{
fill_pool_from_files, fill_pool_from_reflection, method_desc_to_path, reflect_types_for_message,
};
use crate::transport::get_transport;
use crate::{MethodDefinition, ServiceDefinition, json_schema};
use hyper_rustls::HttpsConnector;
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use log::{info, warn};
pub use prost_reflect::DynamicMessage;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tonic::body::BoxBody;
use tonic::metadata::{MetadataKey, MetadataValue};
use tonic::transport::Uri;
use tonic::{IntoRequest, IntoStreamingRequest, Request, Response, Status, Streaming};
use yaak_tls::ClientCertificateConfig;
#[derive(Clone)]
pub struct GrpcConnection {
pool: Arc<RwLock<DescriptorPool>>,
conn: Client<HttpsConnector<HttpConnector>, BoxBody>,
pub uri: Uri,
use_reflection: bool,
}
#[derive(Default, Debug)]
pub struct GrpcStreamError {
pub message: String,
pub status: Option<Status>,
}
impl Error for GrpcStreamError {}
impl Display for GrpcStreamError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.status {
Some(status) => write!(f, "[{}] {}", status, self.message),
None => write!(f, "{}", self.message),
}
}
}
impl From<String> for GrpcStreamError {
fn from(value: String) -> Self {
GrpcStreamError { message: value.to_string(), status: None }
}
}
impl From<Status> for GrpcStreamError {
fn from(s: Status) -> Self {
GrpcStreamError { message: s.message().to_string(), status: Some(s) }
}
}
impl GrpcConnection {
pub async fn method(&self, service: &str, method: &str) -> Result<MethodDescriptor> {
let service = self.service(service).await?;
let method = service
.methods()
.find(|m| m.name() == method)
.ok_or(GenericError("Failed to find method".to_string()))?;
Ok(method)
}
async fn service(&self, service: &str) -> Result<ServiceDescriptor> {
let pool = self.pool.read().await;
let service = pool
.get_service_by_name(service)
.ok_or(GenericError("Failed to find service".to_string()))?;
Ok(service)
}
pub async fn unary(
&self,
service: &str,
method: &str,
message: &str,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<DynamicMessage>> {
if self.use_reflection {
reflect_types_for_message(self.pool.clone(), &self.uri, message, metadata, client_cert)
.await?;
}
let method = &self.method(&service, &method).await?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
deserializer.end()?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req)?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.unary(req, path, codec).await?)
}
pub async fn streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<String>,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<Streaming<DynamicMessage>>> {
let method = &self.method(&service, &method).await?;
let mapped_stream = {
let input_message = method.input();
let pool = self.pool.clone();
let uri = self.uri.clone();
let md = metadata.clone();
let use_reflection = self.use_reflection.clone();
let client_cert = client_cert.clone();
stream.filter_map(move |json| {
let pool = pool.clone();
let uri = uri.clone();
let input_message = input_message.clone();
let md = md.clone();
let use_reflection = use_reflection.clone();
let client_cert = client_cert.clone();
tokio::runtime::Handle::current().block_on(async move {
if use_reflection {
if let Err(e) =
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
{
warn!("Failed to resolve Any types: {e}");
}
}
let mut de = Deserializer::from_str(&json);
match DynamicMessage::deserialize(input_message, &mut de) {
Ok(m) => Some(m),
Err(e) => {
warn!("Failed to deserialize message: {e}");
None
}
}
})
})
};
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
let mut req = mapped_stream.into_streaming_request();
decorate_req(metadata, &mut req)?;
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.streaming(req, path, codec).await?)
}
pub async fn client_streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<String>,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Response<DynamicMessage>> {
let method = &self.method(&service, &method).await?;
let mapped_stream = {
let input_message = method.input();
let pool = self.pool.clone();
let uri = self.uri.clone();
let md = metadata.clone();
let use_reflection = self.use_reflection.clone();
let client_cert = client_cert.clone();
stream.filter_map(move |json| {
let pool = pool.clone();
let uri = uri.clone();
let input_message = input_message.clone();
let md = md.clone();
let use_reflection = use_reflection.clone();
let client_cert = client_cert.clone();
tokio::runtime::Handle::current().block_on(async move {
if use_reflection {
if let Err(e) =
reflect_types_for_message(pool, &uri, &json, &md, client_cert).await
{
warn!("Failed to resolve Any types: {e}");
}
}
let mut de = Deserializer::from_str(&json);
match DynamicMessage::deserialize(input_message, &mut de) {
Ok(m) => Some(m),
Err(e) => {
warn!("Failed to deserialize message: {e}");
None
}
}
})
})
};
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
let mut req = mapped_stream.into_streaming_request();
decorate_req(metadata, &mut req)?;
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client
.client_streaming(req, path, codec)
.await
.map_err(|e| GrpcStreamError { message: e.message().to_string(), status: Some(e) })?)
}
pub async fn server_streaming(
&self,
service: &str,
method: &str,
message: &str,
metadata: &BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>> {
let method = &self.method(&service, &method).await?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)?;
deserializer.end()?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req)?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| GenericError(format!("Failed to connect: {}", e)))?;
Ok(client.server_streaming(req, path, codec).await?)
}
}
/// Configuration for GrpcHandle to compile proto files
#[derive(Clone)]
pub struct GrpcConfig {
/// Path to the protoc include directory (vendored/protoc/include)
pub protoc_include_dir: PathBuf,
/// Path to the yaakprotoc sidecar binary
pub protoc_bin_path: PathBuf,
}
pub struct GrpcHandle {
config: GrpcConfig,
pools: BTreeMap<String, DescriptorPool>,
}
impl GrpcHandle {
pub fn new(config: GrpcConfig) -> Self {
let pools = BTreeMap::new();
Self { pools, config }
}
}
impl GrpcHandle {
/// Remove cached descriptor pool for the given key, if present.
pub fn invalidate_pool(&mut self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) {
let key = make_pool_key(id, uri, proto_files);
self.pools.remove(&key);
}
pub async fn reflect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<bool> {
let server_reflection = proto_files.is_empty();
let key = make_pool_key(id, uri, proto_files);
// If we already have a pool for this key, reuse it and avoid re-reflection
if self.pools.contains_key(&key) {
return Ok(server_reflection);
}
let pool = if server_reflection {
let full_uri = uri_from_str(uri)?;
fill_pool_from_reflection(&full_uri, metadata, validate_certificates, client_cert).await
} else {
fill_pool_from_files(&self.config, proto_files).await
}?;
self.pools.insert(key, pool.clone());
Ok(server_reflection)
}
pub async fn services(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
skip_cache: bool,
) -> Result<Vec<ServiceDefinition>> {
// Ensure we have a pool; reflect only if missing
if skip_cache || self.get_pool(id, uri, proto_files).is_none() {
info!("Reflecting gRPC services for {} at {}", id, uri);
self.reflect(id, uri, proto_files, metadata, validate_certificates, client_cert)
.await?;
}
let pool = self
.get_pool(id, uri, proto_files)
.ok_or(GenericError("Failed to get pool".to_string()))?;
Ok(self.services_from_pool(&pool))
}
fn services_from_pool(&self, pool: &DescriptorPool) -> Vec<ServiceDefinition> {
pool.services()
.map(|s| {
let mut def =
ServiceDefinition { name: s.full_name().to_string(), methods: vec![] };
for method in s.methods() {
let input_message = method.input();
def.methods.push(MethodDefinition {
name: method.name().to_string(),
server_streaming: method.is_server_streaming(),
client_streaming: method.is_client_streaming(),
schema: serde_json::to_string_pretty(&json_schema::message_to_json_schema(
&pool,
input_message,
))
.expect("Failed to serialize JSON schema"),
})
}
def
})
.collect::<Vec<_>>()
}
pub async fn connect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<GrpcConnection> {
let use_reflection = proto_files.is_empty();
if self.get_pool(id, uri, proto_files).is_none() {
self.reflect(
id,
uri,
proto_files,
metadata,
validate_certificates,
client_cert.clone(),
)
.await?;
}
let pool = self
.get_pool(id, uri, proto_files)
.ok_or(GenericError("Failed to get pool".to_string()))?
.clone();
let uri = uri_from_str(uri)?;
let conn = get_transport(validate_certificates, client_cert.clone())?;
Ok(GrpcConnection { pool: Arc::new(RwLock::new(pool)), use_reflection, conn, uri })
}
fn get_pool(&self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> Option<&DescriptorPool> {
self.pools.get(make_pool_key(id, uri, proto_files).as_str())
}
}
pub(crate) fn decorate_req<T>(
metadata: &BTreeMap<String, String>,
req: &mut Request<T>,
) -> Result<()> {
for (k, v) in metadata {
req.metadata_mut()
.insert(MetadataKey::from_str(k.as_str())?, MetadataValue::from_str(v.as_str())?);
}
Ok(())
}
fn uri_from_str(uri_str: &str) -> Result<Uri> {
match Uri::from_str(uri_str) {
Ok(uri) => Ok(uri),
Err(err) => {
// Uri::from_str basically only returns "invalid format" so we add more context here
Err(GenericError(format!("Failed to parse URL, {}", err.to_string())))
}
}
}
fn make_pool_key(id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> String {
let pool_key = format!(
"{}::{}::{}",
id,
uri,
proto_files
.iter()
.map(|p| p.to_string_lossy().to_string())
.collect::<Vec<String>>()
.join(":")
);
format!("{:x}", md5::compute(pool_key))
}

View File

@@ -0,0 +1,452 @@
use crate::any::collect_any_types;
use crate::client::AutoReflectionClient;
use crate::error::Error::GenericError;
use crate::error::Result;
use crate::manager::GrpcConfig;
use anyhow::anyhow;
use async_recursion::async_recursion;
use log::{debug, info, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use std::collections::{BTreeMap, HashSet};
use std::env::temp_dir;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use tokio::fs;
use tokio::process::Command;
use tokio::sync::RwLock;
use tonic::codegen::http::uri::PathAndQuery;
use tonic::transport::Uri;
use tonic_reflection::pb::v1::server_reflection_request::MessageRequest;
use tonic_reflection::pb::v1::server_reflection_response::MessageResponse;
use yaak_tls::ClientCertificateConfig;
pub async fn fill_pool_from_files(
config: &GrpcConfig,
paths: &Vec<PathBuf>,
) -> Result<DescriptorPool> {
let mut pool = DescriptorPool::new();
let random_file_name = format!("{}.desc", uuid::Uuid::new_v4());
let desc_path = temp_dir().join(random_file_name);
// HACK: Remove UNC prefix for Windows paths
let global_import_dir =
dunce::simplified(config.protoc_include_dir.as_path()).to_string_lossy().to_string();
let desc_path = dunce::simplified(desc_path.as_path());
let mut args = vec![
"--include_imports".to_string(),
"--include_source_info".to_string(),
"-I".to_string(),
global_import_dir,
"-o".to_string(),
desc_path.to_string_lossy().to_string(),
];
let mut include_dirs = HashSet::new();
let mut include_protos = HashSet::new();
for p in paths {
if !p.exists() {
continue;
}
// Dirs are added as includes
if p.is_dir() {
include_dirs.insert(p.to_string_lossy().to_string());
continue;
}
let parent = p.as_path().parent();
if let Some(parent_path) = parent {
match find_parent_proto_dir(parent_path) {
None => {
// Add parent/grandparent as fallback
include_dirs.insert(parent_path.to_string_lossy().to_string());
if let Some(grandparent_path) = parent_path.parent() {
include_dirs.insert(grandparent_path.to_string_lossy().to_string());
}
}
Some(p) => {
include_dirs.insert(p.to_string_lossy().to_string());
}
};
} else {
debug!("ignoring {:?} since it does not exist.", parent)
}
include_protos.insert(p.to_string_lossy().to_string());
}
for d in include_dirs.clone() {
args.push("-I".to_string());
args.push(d);
}
for p in include_protos.clone() {
args.push(p);
}
info!("Invoking protoc with {}", args.join(" "));
let out = Command::new(&config.protoc_bin_path)
.args(&args)
.output()
.await
.map_err(|e| GenericError(format!("Failed to run protoc: {}", e)))?;
if !out.status.success() {
return Err(GenericError(format!(
"protoc failed with status {}: {}",
out.status.code().unwrap_or(-1),
String::from_utf8_lossy(out.stderr.as_slice())
)));
}
let bytes = fs::read(desc_path).await?;
let fdp = FileDescriptorSet::decode(bytes.deref())?;
pool.add_file_descriptor_set(fdp)?;
fs::remove_file(desc_path).await?;
Ok(pool)
}
pub async fn fill_pool_from_reflection(
uri: &Uri,
metadata: &BTreeMap<String, String>,
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<DescriptorPool> {
let mut pool = DescriptorPool::new();
let mut client = AutoReflectionClient::new(uri, validate_certificates, client_cert)?;
for service in list_services(&mut client, metadata).await? {
if service == "grpc.reflection.v1alpha.ServerReflection" {
continue;
}
if service == "grpc.reflection.v1.ServerReflection" {
continue;
}
debug!("Fetching descriptors for {}", service);
file_descriptor_set_from_service_name(&service, &mut pool, &mut client, metadata).await;
}
Ok(pool)
}
async fn list_services(
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) -> Result<Vec<String>> {
let response =
client.send_reflection_request(MessageRequest::ListServices("".into()), metadata).await?;
let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp,
_ => panic!("Expected a ListServicesResponse variant"),
};
Ok(list_services_response.service.iter().map(|s| s.name.clone()).collect::<Vec<_>>())
}
async fn file_descriptor_set_from_service_name(
service_name: &str,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
let response = match client
.send_reflection_request(
MessageRequest::FileContainingSymbol(service_name.into()),
metadata,
)
.await
{
Ok(resp) => resp,
Err(e) => {
warn!("Error fetching file descriptor for service {}: {:?}", service_name, e);
return;
}
};
let file_descriptor_response = match response {
MessageResponse::FileDescriptorResponse(resp) => resp,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
add_file_descriptors_to_pool(
file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
}
pub(crate) async fn reflect_types_for_message(
pool: Arc<RwLock<DescriptorPool>>,
uri: &Uri,
json: &str,
metadata: &BTreeMap<String, String>,
client_cert: Option<ClientCertificateConfig>,
) -> Result<()> {
// 1. Collect all Any types in the JSON
let mut extra_types = Vec::new();
collect_any_types(json, &mut extra_types);
if extra_types.is_empty() {
return Ok(()); // nothing to do
}
let mut client = AutoReflectionClient::new(uri, false, client_cert)?;
for extra_type in extra_types {
{
let guard = pool.read().await;
if guard.get_message_by_name(&extra_type).is_some() {
continue;
}
}
info!("Adding file descriptor for {:?} from reflection", extra_type);
let req = MessageRequest::FileContainingSymbol(extra_type.clone().into());
let resp = match client.send_reflection_request(req, metadata).await {
Ok(r) => r,
Err(e) => {
return Err(GenericError(format!(
"Error sending reflection request for @type \"{extra_type}\": {e:?}",
)));
}
};
let files = match resp {
MessageResponse::FileDescriptorResponse(resp) => resp.file_descriptor_proto,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
{
let mut guard = pool.write().await;
add_file_descriptors_to_pool(files, &mut *guard, &mut client, metadata).await;
}
}
Ok(())
}
#[async_recursion]
pub(crate) async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
let mut topo_sort = topology::SimpleTopoSort::new();
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
for fd in fds {
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
topo_sort.insert(fdp.name().to_string(), fdp.dependency.clone());
fd_mapping.insert(fdp.name().to_string(), fdp);
}
for node in topo_sort {
match node {
Ok(node) => {
if let Some(fdp) = fd_mapping.remove(&node) {
pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto");
} else {
file_descriptor_set_by_filename(node.as_str(), pool, client, metadata).await;
}
}
Err(_) => panic!("proto file got cycle!"),
}
}
}
async fn file_descriptor_set_by_filename(
filename: &str,
pool: &mut DescriptorPool,
client: &mut AutoReflectionClient,
metadata: &BTreeMap<String, String>,
) {
// We already fetched this file
if let Some(_) = pool.get_file_by_name(filename) {
return;
}
let msg = MessageRequest::FileByFilename(filename.into());
let response = client.send_reflection_request(msg, metadata).await;
let file_descriptor_response = match response {
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
Ok(_) => {
panic!("Expected a FileDescriptorResponse variant")
}
Err(e) => {
warn!("Error fetching file descriptor for {}: {:?}", filename, e);
return;
}
};
add_file_descriptors_to_pool(
file_descriptor_response.file_descriptor_proto,
pool,
client,
metadata,
)
.await;
}
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
let full_name = md.full_name();
let (namespace, method_name) = full_name
.rsplit_once('.')
.ok_or_else(|| anyhow!("invalid method path"))
.expect("invalid method path");
PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path")
}
mod topology {
use std::collections::{HashMap, HashSet};
pub struct SimpleTopoSort<T> {
out_graph: HashMap<T, HashSet<T>>,
in_graph: HashMap<T, HashSet<T>>,
}
impl<T> SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new() -> Self {
SimpleTopoSort { out_graph: HashMap::new(), in_graph: HashMap::new() }
}
pub fn insert<I: IntoIterator<Item = T>>(&mut self, node: T, deps: I) {
self.out_graph.entry(node.clone()).or_insert(HashSet::new());
for dep in deps {
self.out_graph.entry(node.clone()).or_insert(HashSet::new()).insert(dep.clone());
self.in_graph.entry(dep.clone()).or_insert(HashSet::new()).insert(node.clone());
}
}
}
impl<T> IntoIterator for SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = <SimpleTopoSortIter<T> as Iterator>::Item;
type IntoIter = SimpleTopoSortIter<T>;
fn into_iter(self) -> Self::IntoIter {
SimpleTopoSortIter::new(self)
}
}
pub struct SimpleTopoSortIter<T> {
data: SimpleTopoSort<T>,
zero_indegree: Vec<T>,
}
impl<T> SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new(data: SimpleTopoSort<T>) -> Self {
let mut zero_indegree = Vec::new();
for (node, _) in data.in_graph.iter() {
if !data.out_graph.contains_key(node) {
zero_indegree.push(node.clone());
}
}
for (node, deps) in data.out_graph.iter() {
if deps.is_empty() {
zero_indegree.push(node.clone());
}
}
SimpleTopoSortIter { data, zero_indegree }
}
}
impl<T> Iterator for SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = Result<T, &'static str>;
fn next(&mut self) -> Option<Self::Item> {
if self.zero_indegree.is_empty() {
if self.data.out_graph.is_empty() {
return None;
}
return Some(Err("Cycle detected"));
}
let node = self.zero_indegree.pop().unwrap();
if let Some(parents) = self.data.in_graph.get(&node) {
for parent in parents.iter() {
let deps = self.data.out_graph.get_mut(parent).unwrap();
deps.remove(&node);
if deps.is_empty() {
self.zero_indegree.push(parent.clone());
}
}
}
self.data.out_graph.remove(&node);
Some(Ok(node))
}
}
#[test]
fn test_sort() {
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", []);
for node in topo_sort {
match node {
Ok(n) => assert_eq!(n, "a"),
Err(e) => panic!("err {}", e),
}
}
}
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", ["b"]);
topo_sort.insert("b", []);
let mut iter = topo_sort.into_iter();
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "b"),
_ => panic!("err"),
}
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "a"),
_ => panic!("err"),
}
assert_eq!(iter.next(), None);
}
}
}
fn find_parent_proto_dir(start_path: impl AsRef<Path>) -> Option<PathBuf> {
let mut dir = start_path.as_ref().canonicalize().ok()?;
loop {
if let Some(name) = dir.file_name().and_then(|n| n.to_str()) {
if name == "proto" {
return Some(dir);
}
}
let parent = dir.parent()?;
if parent == dir {
return None; // Reached root
}
dir = parent.to_path_buf();
}
}

View File

@@ -0,0 +1,40 @@
use crate::error::Result;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::Client;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::rt::TokioExecutor;
use log::info;
use tonic::body::BoxBody;
use yaak_tls::{ClientCertificateConfig, get_tls_config};
// I think ALPN breaks this because we're specifying http2_only
const WITH_ALPN: bool = false;
pub(crate) fn get_transport(
validate_certificates: bool,
client_cert: Option<ClientCertificateConfig>,
) -> Result<Client<HttpsConnector<HttpConnector>, BoxBody>> {
let tls_config = get_tls_config(validate_certificates, WITH_ALPN, client_cert.clone())?;
let mut http = HttpConnector::new();
http.enforce_http(false);
let connector = HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http2()
.build();
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(0)
.http2_only(true)
.build(connector);
info!(
"Created gRPC client validate_certs={} client_cert={}",
validate_certificates,
client_cert.is_some()
);
Ok(client)
}