From 1050ac5e3cdfefb307d8e9baa8b837c29eff1ecb Mon Sep 17 00:00:00 2001 From: Hao Xiang Date: Wed, 30 Oct 2024 05:19:11 +0800 Subject: [PATCH] fix(grpc): proto dep topo order to solve panic (#130) --- src-tauri/Cargo.lock | 1 + src-tauri/yaak_grpc/Cargo.toml | 1 + src-tauri/yaak_grpc/src/proto.rs | 212 ++++++++++++++++++++++++++----- 3 files changed, 179 insertions(+), 35 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 0e9caf64..a69c23f8 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -8038,6 +8038,7 @@ name = "yaak_grpc" version = "0.1.0" dependencies = [ "anyhow", + "async-recursion", "dunce", "hyper 0.14.30", "hyper-rustls 0.24.2", diff --git a/src-tauri/yaak_grpc/Cargo.toml b/src-tauri/yaak_grpc/Cargo.toml index a8f2ae39..5a3359f0 100644 --- a/src-tauri/yaak_grpc/Cargo.toml +++ b/src-tauri/yaak_grpc/Cargo.toml @@ -22,3 +22,4 @@ tauri = { workspace = true } tauri-plugin-shell = { workspace = true } md5 = "0.7.0" dunce = "1.0.4" +async-recursion = "1.1.1" \ No newline at end of file diff --git a/src-tauri/yaak_grpc/src/proto.rs b/src-tauri/yaak_grpc/src/proto.rs index ef300908..8217c7e2 100644 --- a/src-tauri/yaak_grpc/src/proto.rs +++ b/src-tauri/yaak_grpc/src/proto.rs @@ -4,6 +4,7 @@ use std::path::PathBuf; use std::str::FromStr; use anyhow::anyhow; +use async_recursion::async_recursion; use hyper::client::HttpConnector; use hyper::Client; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; @@ -38,9 +39,8 @@ pub async fn fill_pool_from_files( .expect("failed to resolve protoc include directory"); // HACK: Remove UNC prefix for Windows paths - let global_import_dir = dunce::simplified(global_import_dir.as_path()) - .to_string_lossy() - .to_string(); + let global_import_dir = + dunce::simplified(global_import_dir.as_path()).to_string_lossy().to_string(); let desc_path = dunce::simplified(desc_path.as_path()); let mut args = vec![ @@ -89,12 +89,9 @@ pub async fn fill_pool_from_files( let bytes = fs::read(desc_path).await.map_err(|e| e.to_string())?; let fdp = FileDescriptorSet::decode(bytes.deref()).map_err(|e| e.to_string())?; - pool.add_file_descriptor_set(fdp) - .map_err(|e| e.to_string())?; + pool.add_file_descriptor_set(fdp).map_err(|e| e.to_string())?; - fs::remove_file(desc_path) - .await - .map_err(|e| e.to_string())?; + fs::remove_file(desc_path).await.map_err(|e| e.to_string())?; Ok(pool) } @@ -107,6 +104,10 @@ pub async fn fill_pool_from_reflection(uri: &Uri) -> Result Client, BoxBody> { http_connector.enforce_http(false); http_connector }); - Client::builder() - .pool_max_idle_per_host(0) - .http2_only(true) - .build(connector) + Client::builder().pool_max_idle_per_host(0).http2_only(true).build(connector) } async fn list_services( @@ -137,11 +135,7 @@ async fn list_services( _ => panic!("Expected a ListServicesResponse variant"), }; - Ok(list_services_response - .service - .iter() - .map(|s| s.name.clone()) - .collect::>()) + Ok(list_services_response.service.iter().map(|s| s.name.clone()).collect::>()) } async fn file_descriptor_set_from_service_name( @@ -153,14 +147,11 @@ async fn file_descriptor_set_from_service_name( client, MessageRequest::FileContainingSymbol(service_name.into()), ) - .await + .await { Ok(resp) => resp, Err(e) => { - warn!( - "Error fetching file descriptor for service {}: {}", - service_name, e - ); + warn!("Error fetching file descriptor for service {}: {}", service_name, e); return; } }; @@ -170,16 +161,37 @@ async fn file_descriptor_set_from_service_name( _ => panic!("Expected a FileDescriptorResponse variant"), }; - for fd in file_descriptor_response.file_descriptor_proto { + add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client) + .await; +} + +#[async_recursion] +async fn add_file_descriptors_to_pool( + fds: Vec>, + pool: &mut DescriptorPool, + client: &mut ServerReflectionClient, BoxBody>>, +) { + let mut topo_sort = topology::SimpleTopoSort::new(); + let mut fd_mapping = std::collections::HashMap::with_capacity(fds.len()); + + for fd in fds { let fdp = FileDescriptorProto::decode(fd.deref()).unwrap(); - // Add deps first or else we'll get an error - for dep_name in fdp.clone().dependency { - file_descriptor_set_by_filename(&dep_name, pool, client).await; - } + topo_sort.insert(fdp.name().to_string(), fdp.dependency.clone()); + fd_mapping.insert(fdp.name().to_string(), fdp); + } - pool.add_file_descriptor_proto(fdp) - .expect("add file descriptor proto"); + for node in topo_sort { + match node { + Ok(node) => { + if let Some(fdp) = fd_mapping.remove(&node) { + pool.add_file_descriptor_proto(fdp).expect("add file descriptor proto"); + } else { + file_descriptor_set_by_filename(node.as_str(), pool, client).await; + } + } + Err(_) => panic!("proto file got cycle!"), + } } } @@ -206,11 +218,8 @@ async fn file_descriptor_set_by_filename( } }; - for fd in file_descriptor_response.file_descriptor_proto { - let fdp = FileDescriptorProto::decode(fd.deref()).unwrap(); - pool.add_file_descriptor_proto(fdp) - .expect("add file descriptor proto"); - } + add_file_descriptors_to_pool(file_descriptor_response.file_descriptor_proto, pool, client) + .await; } async fn send_reflection_request( @@ -249,4 +258,137 @@ pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery { .ok_or_else(|| anyhow!("invalid method path")) .expect("invalid method path"); PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path") -} \ No newline at end of file +} + +mod topology { + use std::collections::{HashMap, HashSet}; + + pub struct SimpleTopoSort { + out_graph: HashMap>, + in_graph: HashMap>, + } + + impl SimpleTopoSort + where + T: Eq + std::hash::Hash + Clone, + { + pub fn new() -> Self { + SimpleTopoSort { + out_graph: HashMap::new(), + in_graph: HashMap::new(), + } + } + + pub fn insert>(&mut self, node: T, deps: I) { + self.out_graph.entry(node.clone()).or_insert(HashSet::new()); + for dep in deps { + self.out_graph.entry(node.clone()).or_insert(HashSet::new()).insert(dep.clone()); + self.in_graph.entry(dep.clone()).or_insert(HashSet::new()).insert(node.clone()); + } + } + } + + impl IntoIterator for SimpleTopoSort + where + T: Eq + std::hash::Hash + Clone, + { + type IntoIter = SimpleTopoSortIter; + type Item = as Iterator>::Item; + + fn into_iter(self) -> Self::IntoIter { + SimpleTopoSortIter::new(self) + } + } + + pub struct SimpleTopoSortIter { + data: SimpleTopoSort, + zero_indegree: Vec, + } + + impl SimpleTopoSortIter + where + T: Eq + std::hash::Hash + Clone, + { + pub fn new(data: SimpleTopoSort) -> Self { + let mut zero_indegree = Vec::new(); + for (node, _) in data.in_graph.iter() { + if !data.out_graph.contains_key(node) { + zero_indegree.push(node.clone()); + } + } + for (node, deps) in data.out_graph.iter(){ + if deps.is_empty(){ + zero_indegree.push(node.clone()); + } + } + + SimpleTopoSortIter { + data, + zero_indegree, + } + } + } + + impl Iterator for SimpleTopoSortIter + where + T: Eq + std::hash::Hash + Clone, + { + type Item = Result; + + fn next(&mut self) -> Option { + if self.zero_indegree.is_empty() { + if self.data.out_graph.is_empty() { + return None; + } + return Some(Err("Cycle detected")); + } + + let node = self.zero_indegree.pop().unwrap(); + if let Some(parents) = self.data.in_graph.get(&node){ + for parent in parents.iter(){ + let deps = self.data.out_graph.get_mut(parent).unwrap(); + deps.remove(&node); + if deps.is_empty() { + self.zero_indegree.push(parent.clone()); + } + } + } + self.data.out_graph.remove(&node); + + Some(Ok(node)) + } + } + + #[test] + fn test_sort(){ + { + let mut topo_sort = SimpleTopoSort::new(); + topo_sort.insert("a", []); + + for node in topo_sort { + match node { + Ok(n) => assert_eq!(n, "a"), + Err(e) => panic!("err {}", e), + } + } + } + + { + let mut topo_sort = SimpleTopoSort::new(); + topo_sort.insert("a", ["b"]); + topo_sort.insert("b", []); + + let mut iter = topo_sort.into_iter(); + match iter.next() { + Some(Ok(n)) => assert_eq!(n, "b"), + _ => panic!("err"), + } + match iter.next() { + Some(Ok(n)) => assert_eq!(n, "a"), + _ => panic!("err"), + } + assert_eq!(iter.next(), None); + } + } + +}