Filesystem Sync (#142)

This commit is contained in:
Gregory Schier
2025-01-03 20:41:00 -08:00
committed by GitHub
parent 6ad27c4458
commit 31440eea76
159 changed files with 4296 additions and 1016 deletions

View File

@@ -0,0 +1,52 @@
use prost_reflect::prost::Message;
use prost_reflect::{DynamicMessage, MethodDescriptor};
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
use tonic::Status;
#[derive(Clone)]
pub struct DynamicCodec(MethodDescriptor);
impl DynamicCodec {
#[allow(dead_code)]
pub fn new(md: MethodDescriptor) -> Self {
Self(md)
}
}
impl Codec for DynamicCodec {
type Encode = DynamicMessage;
type Decode = DynamicMessage;
type Encoder = Self;
type Decoder = Self;
fn encoder(&mut self) -> Self::Encoder {
self.clone()
}
fn decoder(&mut self) -> Self::Decoder {
self.clone()
}
}
impl Encoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
item.encode(dst)
.expect("buffer is too small to decode this message");
Ok(())
}
}
impl Decoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let mut msg = DynamicMessage::new(self.0.output());
msg.merge(src)
.map_err(|err| Status::internal(err.to_string()))?;
Ok(Some(msg))
}
}

View File

@@ -0,0 +1,179 @@
use prost_reflect::{DescriptorPool, MessageDescriptor};
use prost_types::field_descriptor_proto;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct JsonSchemaEntry {
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(rename = "type")]
type_: JsonType,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
properties: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
enum_: Option<Vec<String>>,
/// Don't allow any other properties in the object
additional_properties: bool,
/// Set all properties to required
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
items: Option<Box<JsonSchemaEntry>>,
}
enum JsonType {
String,
Number,
Object,
Array,
Boolean,
Null,
_UNKNOWN,
}
impl Default for JsonType {
fn default() -> Self {
JsonType::_UNKNOWN
}
}
impl serde::Serialize for JsonType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
JsonType::String => serializer.serialize_str("string"),
JsonType::Number => serializer.serialize_str("number"),
JsonType::Object => serializer.serialize_str("object"),
JsonType::Array => serializer.serialize_str("array"),
JsonType::Boolean => serializer.serialize_str("boolean"),
JsonType::Null => serializer.serialize_str("null"),
JsonType::_UNKNOWN => serializer.serialize_str("unknown"),
}
}
}
impl<'de> serde::Deserialize<'de> for JsonType {
fn deserialize<D>(deserializer: D) -> Result<JsonType, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.as_str() {
"string" => Ok(JsonType::String),
"number" => Ok(JsonType::Number),
"object" => Ok(JsonType::Object),
"array" => Ok(JsonType::Array),
"boolean" => Ok(JsonType::Boolean),
"null" => Ok(JsonType::Null),
_ => Ok(JsonType::_UNKNOWN),
}
}
}
pub fn message_to_json_schema(
pool: &DescriptorPool,
message: MessageDescriptor,
) -> JsonSchemaEntry {
let mut schema = JsonSchemaEntry {
title: Some(message.name().to_string()),
type_: JsonType::Object, // Messages are objects
..Default::default()
};
let mut properties = HashMap::new();
message.fields().for_each(|f| match f.kind() {
prost_reflect::Kind::Message(m) => {
properties.insert(f.name().to_string(), message_to_json_schema(pool, m));
}
prost_reflect::Kind::Enum(e) => {
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()),
enum_: Some(e.values().map(|v| v.name().to_string()).collect::<Vec<_>>()),
..Default::default()
},
);
}
_ => {
// TODO: Handle repeated label
match f.field_descriptor_proto().label() {
field_descriptor_proto::Label::Repeated => {
// TODO: Handle more complex repeated types. This just handles primitives for now
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: JsonType::Array,
items: Some(Box::new(JsonSchemaEntry {
type_: map_proto_type_to_json_type(
f.field_descriptor_proto().r#type(),
),
..Default::default()
})),
..Default::default()
},
);
}
_ => {
// Regular JSON field
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()),
..Default::default()
},
);
}
};
}
});
schema.properties = Some(properties);
// All proto 3 fields are optional, so maybe we could
// make this a setting?
// schema.required = Some(
// message
// .fields()
// .map(|f| f.name().to_string())
// .collect::<Vec<_>>(),
// );
schema
}
fn map_proto_type_to_json_type(proto_type: field_descriptor_proto::Type) -> JsonType {
match proto_type {
field_descriptor_proto::Type::Double => JsonType::Number,
field_descriptor_proto::Type::Float => JsonType::Number,
field_descriptor_proto::Type::Int64 => JsonType::Number,
field_descriptor_proto::Type::Uint64 => JsonType::Number,
field_descriptor_proto::Type::Int32 => JsonType::Number,
field_descriptor_proto::Type::Fixed64 => JsonType::Number,
field_descriptor_proto::Type::Fixed32 => JsonType::Number,
field_descriptor_proto::Type::Bool => JsonType::Boolean,
field_descriptor_proto::Type::String => JsonType::String,
field_descriptor_proto::Type::Group => JsonType::_UNKNOWN,
field_descriptor_proto::Type::Message => JsonType::Object,
field_descriptor_proto::Type::Bytes => JsonType::String,
field_descriptor_proto::Type::Uint32 => JsonType::Number,
field_descriptor_proto::Type::Enum => JsonType::String,
field_descriptor_proto::Type::Sfixed32 => JsonType::Number,
field_descriptor_proto::Type::Sfixed64 => JsonType::Number,
field_descriptor_proto::Type::Sint32 => JsonType::Number,
field_descriptor_proto::Type::Sint64 => JsonType::Number,
}
}

View File

@@ -0,0 +1,52 @@
use prost_reflect::{DynamicMessage, MethodDescriptor, SerializeOptions};
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
mod codec;
mod json_schema;
pub mod manager;
mod proto;
pub use tonic::metadata::*;
pub use tonic::Code;
pub fn serialize_options() -> SerializeOptions {
SerializeOptions::new().skip_default_fields(false)
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct ServiceDefinition {
pub name: String,
pub methods: Vec<MethodDefinition>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct MethodDefinition {
pub name: String,
pub schema: String,
pub client_streaming: bool,
pub server_streaming: bool,
}
static SERIALIZE_OPTIONS: &'static SerializeOptions = &SerializeOptions::new()
.skip_default_fields(false)
.stringify_64_bit_integers(false);
pub fn serialize_message(msg: &DynamicMessage) -> Result<String, String> {
let mut buf = Vec::new();
let mut se = serde_json::Serializer::pretty(&mut buf);
msg.serialize_with_options(&mut se, SERIALIZE_OPTIONS)
.map_err(|e| e.to_string())?;
let s = String::from_utf8(buf).expect("serde_json to emit valid utf8");
Ok(s)
}
pub fn deserialize_message(msg: &str, method: MethodDescriptor) -> Result<DynamicMessage, String> {
let mut deserializer = Deserializer::from_str(&msg);
let req_message = DynamicMessage::deserialize(method.input(), &mut deserializer)
.map_err(|e| e.to_string())?;
deserializer.end().map_err(|e| e.to_string())?;
Ok(req_message)
}

View File

@@ -0,0 +1,304 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::str::FromStr;
use hyper::client::HttpConnector;
use hyper::Client;
use hyper_rustls::HttpsConnector;
pub use prost_reflect::DynamicMessage;
use prost_reflect::{DescriptorPool, MethodDescriptor, ServiceDescriptor};
use serde_json::Deserializer;
use tauri::AppHandle;
use tokio_stream::wrappers::ReceiverStream;
use tonic::body::BoxBody;
use tonic::metadata::{MetadataKey, MetadataValue};
use tonic::transport::Uri;
use tonic::{IntoRequest, IntoStreamingRequest, Request, Response, Status, Streaming};
use crate::codec::DynamicCodec;
use crate::proto::{
fill_pool_from_files, fill_pool_from_reflection, get_transport, method_desc_to_path,
};
use crate::{json_schema, MethodDefinition, ServiceDefinition};
#[derive(Clone)]
pub struct GrpcConnection {
pool: DescriptorPool,
conn: Client<HttpsConnector<HttpConnector>, BoxBody>,
pub uri: Uri,
}
#[derive(Default, Debug)]
pub struct StreamError {
pub message: String,
pub status: Option<Status>,
}
impl From<String> for StreamError {
fn from(value: String) -> Self {
StreamError {
message: value.to_string(),
status: None,
}
}
}
impl From<Status> for StreamError {
fn from(s: Status) -> Self {
StreamError {
message: s.message().to_string(),
status: Some(s),
}
}
}
impl GrpcConnection {
pub fn service(&self, service: &str) -> Result<ServiceDescriptor, String> {
let service = self
.pool
.get_service_by_name(service)
.ok_or("Failed to find service")?;
Ok(service)
}
pub fn method(&self, service: &str, method: &str) -> Result<MethodDescriptor, String> {
let service = self.service(service)?;
let method = service
.methods()
.find(|m| m.name() == method)
.ok_or("Failed to find method")?;
Ok(method)
}
pub async fn unary(
&self,
service: &str,
method: &str,
message: &str,
metadata: BTreeMap<String, String>,
) -> Result<Response<DynamicMessage>, StreamError> {
let method = &self.method(&service, &method)?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)
.map_err(|e| e.to_string())?;
deserializer.end().unwrap();
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req).map_err(|e| e.to_string())?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.unwrap();
Ok(client.unary(req, path, codec).await?)
}
pub async fn streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<DynamicMessage>,
metadata: BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>, StreamError> {
let method = &self.method(&service, &method)?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = stream.into_streaming_request();
decorate_req(metadata, &mut req).map_err(|e| e.to_string())?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| e.to_string())?;
Ok(client.streaming(req, path, codec).await?)
}
pub async fn client_streaming(
&self,
service: &str,
method: &str,
stream: ReceiverStream<DynamicMessage>,
metadata: BTreeMap<String, String>,
) -> Result<Response<DynamicMessage>, StreamError> {
let method = &self.method(&service, &method)?;
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = stream.into_streaming_request();
decorate_req(metadata, &mut req).map_err(|e| e.to_string())?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.unwrap();
client
.client_streaming(req, path, codec)
.await
.map_err(|e| StreamError {
message: e.message().to_string(),
status: Some(e),
})
}
pub async fn server_streaming(
&self,
service: &str,
method: &str,
message: &str,
metadata: BTreeMap<String, String>,
) -> Result<Response<Streaming<DynamicMessage>>, StreamError> {
let method = &self.method(&service, &method)?;
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer)
.map_err(|e| e.to_string())?;
deserializer.end().unwrap();
let mut client = tonic::client::Grpc::with_origin(self.conn.clone(), self.uri.clone());
let mut req = req_message.into_request();
decorate_req(metadata, &mut req).map_err(|e| e.to_string())?;
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.map_err(|e| e.to_string())?;
Ok(client.server_streaming(req, path, codec).await?)
}
}
pub struct GrpcHandle {
app_handle: AppHandle,
pools: BTreeMap<String, DescriptorPool>,
}
impl GrpcHandle {
pub fn new(app_handle: &AppHandle) -> Self {
let pools = BTreeMap::new();
Self {
pools,
app_handle: app_handle.clone(),
}
}
}
impl GrpcHandle {
pub async fn reflect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
) -> Result<(), String> {
let pool = if proto_files.is_empty() {
let full_uri = uri_from_str(uri)?;
fill_pool_from_reflection(&full_uri).await
} else {
fill_pool_from_files(&self.app_handle, proto_files).await
}?;
self.pools
.insert(make_pool_key(id, uri, proto_files), pool.clone());
Ok(())
}
pub async fn services(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
) -> Result<Vec<ServiceDefinition>, String> {
// Ensure reflection is up-to-date
self.reflect(id, uri, proto_files).await?;
let pool = self
.get_pool(id, uri, proto_files)
.ok_or("Failed to get pool".to_string())?;
Ok(self.services_from_pool(&pool))
}
fn services_from_pool(&self, pool: &DescriptorPool) -> Vec<ServiceDefinition> {
pool.services()
.map(|s| {
let mut def = ServiceDefinition {
name: s.full_name().to_string(),
methods: vec![],
};
for method in s.methods() {
let input_message = method.input();
def.methods.push(MethodDefinition {
name: method.name().to_string(),
server_streaming: method.is_server_streaming(),
client_streaming: method.is_client_streaming(),
schema: serde_json::to_string_pretty(&json_schema::message_to_json_schema(
&pool,
input_message,
))
.unwrap(),
})
}
def
})
.collect::<Vec<_>>()
}
pub async fn connect(
&mut self,
id: &str,
uri: &str,
proto_files: &Vec<PathBuf>,
) -> Result<GrpcConnection, String> {
self.reflect(id, uri, proto_files).await?;
let pool = self
.get_pool(id, uri, proto_files)
.ok_or("Failed to get pool")?;
let uri = uri_from_str(uri)?;
let conn = get_transport();
let connection = GrpcConnection {
pool: pool.clone(),
conn,
uri,
};
Ok(connection)
}
fn get_pool(&self, id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> Option<&DescriptorPool> {
self.pools.get(make_pool_key(id, uri, proto_files).as_str())
}
}
fn decorate_req<T>(metadata: BTreeMap<String, String>, req: &mut Request<T>) -> Result<(), String> {
for (k, v) in metadata {
req.metadata_mut().insert(
MetadataKey::from_str(k.as_str()).map_err(|e| e.to_string())?,
MetadataValue::from_str(v.as_str()).map_err(|e| e.to_string())?,
);
}
Ok(())
}
fn uri_from_str(uri_str: &str) -> Result<Uri, String> {
match Uri::from_str(uri_str) {
Ok(uri) => Ok(uri),
Err(err) => {
// Uri::from_str basically only returns "invalid format" so we add more context here
Err(format!("Failed to parse URL, {}", err.to_string()))
}
}
}
fn make_pool_key(id: &str, uri: &str, proto_files: &Vec<PathBuf>) -> String {
let pool_key = format!(
"{}::{}::{}",
id,
uri,
proto_files
.iter()
.map(|p| p.to_string_lossy().to_string())
.collect::<Vec<String>>()
.join(":")
);
format!("{:x}", md5::compute(pool_key))
}

View File

@@ -0,0 +1,394 @@
use std::env::temp_dir;
use std::ops::Deref;
use std::path::PathBuf;
use std::str::FromStr;
use anyhow::anyhow;
use async_recursion::async_recursion;
use hyper::client::HttpConnector;
use hyper::Client;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use log::{debug, warn};
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use tauri::path::BaseDirectory;
use tauri::{AppHandle, Manager};
use tauri_plugin_shell::ShellExt;
use tokio::fs;
use tokio_stream::StreamExt;
use tonic::body::BoxBody;
use tonic::codegen::http::uri::PathAndQuery;
use tonic::transport::Uri;
use tonic::Request;
use tonic_reflection::pb::server_reflection_client::ServerReflectionClient;
use tonic_reflection::pb::server_reflection_request::MessageRequest;
use tonic_reflection::pb::server_reflection_response::MessageResponse;
use tonic_reflection::pb::ServerReflectionRequest;
pub async fn fill_pool_from_files(
app_handle: &AppHandle,
paths: &Vec<PathBuf>,
) -> Result<DescriptorPool, String> {
let mut pool = DescriptorPool::new();
let random_file_name = format!("{}.desc", uuid::Uuid::new_v4());
let desc_path = temp_dir().join(random_file_name);
let global_import_dir = app_handle
.path()
.resolve("vendored/protoc/protoc-include", BaseDirectory::Resource)
.expect("failed to resolve protoc include directory");
// HACK: Remove UNC prefix for Windows paths
let global_import_dir =
dunce::simplified(global_import_dir.as_path()).to_string_lossy().to_string();
let desc_path = dunce::simplified(desc_path.as_path());
let mut args = vec![
"--include_imports".to_string(),
"--include_source_info".to_string(),
"-I".to_string(),
global_import_dir,
"-o".to_string(),
desc_path.to_string_lossy().to_string(),
];
for p in paths {
if p.as_path().exists() {
args.push(p.to_string_lossy().to_string());
} else {
continue;
}
let parent = p.as_path().parent();
if let Some(parent_path) = parent {
args.push("-I".to_string());
args.push(parent_path.to_string_lossy().to_string());
args.push("-I".to_string());
args.push(parent_path.parent().unwrap().to_string_lossy().to_string());
} else {
debug!("ignoring {:?} since it does not exist.", parent)
}
}
let out = app_handle
.shell()
.sidecar("yaakprotoc")
.expect("yaakprotoc not found")
.args(args)
.output()
.await
.expect("yaakprotoc failed to run");
if !out.status.success() {
return Err(format!(
"protoc failed with status {}: {}",
out.status.code().unwrap(),
String::from_utf8_lossy(out.stderr.as_slice())
));
}
let bytes = fs::read(desc_path).await.map_err(|e| e.to_string())?;
let fdp = FileDescriptorSet::decode(bytes.deref()).map_err(|e| e.to_string())?;
pool.add_file_descriptor_set(fdp).map_err(|e| e.to_string())?;
fs::remove_file(desc_path).await.map_err(|e| e.to_string())?;
Ok(pool)
}
pub async fn fill_pool_from_reflection(uri: &Uri) -> Result<DescriptorPool, String> {
let mut pool = DescriptorPool::new();
let mut client = ServerReflectionClient::with_origin(get_transport(), uri.clone());
for service in list_services(&mut client).await? {
if service == "grpc.reflection.v1alpha.ServerReflection" {
continue;
}
if service == "grpc.reflection.v1.ServerReflection"{
// TODO: update reflection client to use v1
continue;
}
file_descriptor_set_from_service_name(&service, &mut pool, &mut client).await;
}
Ok(pool)
}
pub fn get_transport() -> Client<HttpsConnector<HttpConnector>, BoxBody> {
let connector = HttpsConnectorBuilder::new().with_native_roots();
let connector = connector.https_or_http().enable_http2().wrap_connector({
let mut http_connector = HttpConnector::new();
http_connector.enforce_http(false);
http_connector
});
Client::builder().pool_max_idle_per_host(0).http2_only(true).build(connector)
}
async fn list_services(
reflect_client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
) -> Result<Vec<String>, String> {
let response =
send_reflection_request(reflect_client, MessageRequest::ListServices("".into())).await?;
let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp,
_ => panic!("Expected a ListServicesResponse variant"),
};
Ok(list_services_response.service.iter().map(|s| s.name.clone()).collect::<Vec<_>>())
}
async fn file_descriptor_set_from_service_name(
service_name: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
) {
let response = match send_reflection_request(
client,
MessageRequest::FileContainingSymbol(service_name.into()),
)
.await
{
Ok(resp) => resp,
Err(e) => {
warn!("Error fetching file descriptor for service {}: {}", service_name, e);
return;
}
};
let file_descriptor_response = match response {
MessageResponse::FileDescriptorResponse(resp) => resp,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client)
.await;
}
#[async_recursion]
async fn add_file_descriptors_to_pool(
fds: Vec<Vec<u8>>,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
) {
let mut topo_sort = topology::SimpleTopoSort::new();
let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len());
for fd in fds {
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
topo_sort.insert(fdp.name().to_string(), fdp.dependency.clone());
fd_mapping.insert(fdp.name().to_string(), fdp);
}
for node in topo_sort {
match node {
Ok(node) => {
if let Some(fdp) = fd_mapping.remove(&node) {
pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto");
} else {
file_descriptor_set_by_filename(node.as_str(), pool, client).await;
}
}
Err(_) => panic!("proto file got cycle!"),
}
}
}
async fn file_descriptor_set_by_filename(
filename: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
) {
// We already fetched this file
if let Some(_) = pool.get_file_by_name(filename) {
return;
}
let response =
send_reflection_request(client, MessageRequest::FileByFilename(filename.into())).await;
let file_descriptor_response = match response {
Ok(MessageResponse::FileDescriptorResponse(resp)) => resp,
Ok(_) => {
panic!("Expected a FileDescriptorResponse variant")
}
Err(e) => {
warn!("Error fetching file descriptor for {}: {}", filename, e);
return;
}
};
add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client)
.await;
}
async fn send_reflection_request(
client: &mut ServerReflectionClient<Client<HttpsConnector<HttpConnector>, BoxBody>>,
message: MessageRequest,
) -> Result<MessageResponse, String> {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message),
};
let request = Request::new(tokio_stream::once(reflection_request));
client
.server_reflection_info(request)
.await
.map_err(|e| match e.code() {
tonic::Code::Unavailable => "Failed to connect to endpoint".to_string(),
tonic::Code::Unauthenticated => "Authentication failed".to_string(),
tonic::Code::DeadlineExceeded => "Deadline exceeded".to_string(),
_ => e.to_string(),
})?
.into_inner()
.next()
.await
.expect("steamed response")
.map_err(|e| e.to_string())?
.message_response
.ok_or("No reflection response".to_string())
}
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
let full_name = md.full_name();
let (namespace, method_name) = full_name
.rsplit_once('.')
.ok_or_else(|| anyhow!("invalid method path"))
.expect("invalid method path");
PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path")
}
mod topology {
use std::collections::{HashMap, HashSet};
pub struct SimpleTopoSort<T> {
out_graph: HashMap<T, HashSet<T>>,
in_graph: HashMap<T, HashSet<T>>,
}
impl<T> SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new() -> Self {
SimpleTopoSort {
out_graph: HashMap::new(),
in_graph: HashMap::new(),
}
}
pub fn insert<I: IntoIterator<Item = T>>(&mut self, node: T, deps: I) {
self.out_graph.entry(node.clone()).or_insert(HashSet::new());
for dep in deps {
self.out_graph.entry(node.clone()).or_insert(HashSet::new()).insert(dep.clone());
self.in_graph.entry(dep.clone()).or_insert(HashSet::new()).insert(node.clone());
}
}
}
impl<T> IntoIterator for SimpleTopoSort<T>
where
T: Eq + std::hash::Hash + Clone,
{
type IntoIter = SimpleTopoSortIter<T>;
type Item = <SimpleTopoSortIter<T> as Iterator>::Item;
fn into_iter(self) -> Self::IntoIter {
SimpleTopoSortIter::new(self)
}
}
pub struct SimpleTopoSortIter<T> {
data: SimpleTopoSort<T>,
zero_indegree: Vec<T>,
}
impl<T> SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
pub fn new(data: SimpleTopoSort<T>) -> Self {
let mut zero_indegree = Vec::new();
for (node, _) in data.in_graph.iter() {
if !data.out_graph.contains_key(node) {
zero_indegree.push(node.clone());
}
}
for (node, deps) in data.out_graph.iter(){
if deps.is_empty(){
zero_indegree.push(node.clone());
}
}
SimpleTopoSortIter {
data,
zero_indegree,
}
}
}
impl<T> Iterator for SimpleTopoSortIter<T>
where
T: Eq + std::hash::Hash + Clone,
{
type Item = Result<T, &'static str>;
fn next(&mut self) -> Option<Self::Item> {
if self.zero_indegree.is_empty() {
if self.data.out_graph.is_empty() {
return None;
}
return Some(Err("Cycle detected"));
}
let node = self.zero_indegree.pop().unwrap();
if let Some(parents) = self.data.in_graph.get(&node){
for parent in parents.iter(){
let deps = self.data.out_graph.get_mut(parent).unwrap();
deps.remove(&node);
if deps.is_empty() {
self.zero_indegree.push(parent.clone());
}
}
}
self.data.out_graph.remove(&node);
Some(Ok(node))
}
}
#[test]
fn test_sort(){
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", []);
for node in topo_sort {
match node {
Ok(n) => assert_eq!(n, "a"),
Err(e) => panic!("err {}", e),
}
}
}
{
let mut topo_sort = SimpleTopoSort::new();
topo_sort.insert("a", ["b"]);
topo_sort.insert("b", []);
let mut iter = topo_sort.into_iter();
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "b"),
_ => panic!("err"),
}
match iter.next() {
Some(Ok(n)) => assert_eq!(n, "a"),
_ => panic!("err"),
}
assert_eq!(iter.next(), None);
}
}
}