Hooked up test call from frontend!

This commit is contained in:
Gregory Schier
2024-01-29 20:50:43 -08:00
parent 219a6b78da
commit dbdce4cf9a
9 changed files with 821 additions and 9 deletions

View File

@@ -0,0 +1,52 @@
use prost_reflect::prost::Message;
use prost_reflect::{DynamicMessage, MethodDescriptor};
use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
use tonic::Status;
#[derive(Clone)]
pub struct DynamicCodec(MethodDescriptor);
impl DynamicCodec {
#[allow(dead_code)]
pub fn new(md: MethodDescriptor) -> Self {
Self(md)
}
}
impl Codec for DynamicCodec {
type Encode = DynamicMessage;
type Decode = DynamicMessage;
type Encoder = Self;
type Decoder = Self;
fn encoder(&mut self) -> Self::Encoder {
self.clone()
}
fn decoder(&mut self) -> Self::Decoder {
self.clone()
}
}
impl Encoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn encode(&mut self, item: Self::Item, dst: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
item.encode(dst)
.expect("buffer is too small to decode this message");
Ok(())
}
}
impl Decoder for DynamicCodec {
type Item = DynamicMessage;
type Error = Status;
fn decode(&mut self, src: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
let mut msg = DynamicMessage::new(self.0.output());
msg.merge(src)
.map_err(|err| Status::internal(err.to_string()))?;
Ok(Some(msg))
}
}

View File

@@ -0,0 +1,175 @@
use std::collections::HashMap;
use prost_reflect::{DescriptorPool, MessageDescriptor};
use prost_types::field_descriptor_proto;
use serde::{Deserialize, Serialize};
#[derive(Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct JsonSchemaEntry {
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(rename = "type")]
type_: JsonType,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
properties: Option<HashMap<String, JsonSchemaEntry>>,
#[serde(rename = "enum", skip_serializing_if = "Option::is_none")]
enum_: Option<Vec<String>>,
/// Don't allow any other properties in the object
#[serde(skip_serializing_if = "Option::is_none")]
additional_properties: Option<bool>,
/// Set all properties to required
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
items: Option<Box<JsonSchemaEntry>>,
}
enum JsonType {
String,
Number,
Object,
Array,
Boolean,
Null,
_UNKNOWN,
}
impl Default for JsonType {
fn default() -> Self {
JsonType::_UNKNOWN
}
}
impl serde::Serialize for JsonType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
JsonType::String => serializer.serialize_str("string"),
JsonType::Number => serializer.serialize_str("number"),
JsonType::Object => serializer.serialize_str("object"),
JsonType::Array => serializer.serialize_str("array"),
JsonType::Boolean => serializer.serialize_str("boolean"),
JsonType::Null => serializer.serialize_str("null"),
JsonType::_UNKNOWN => serializer.serialize_str("unknown"),
}
}
}
impl<'de> serde::Deserialize<'de> for JsonType {
fn deserialize<D>(deserializer: D) -> Result<JsonType, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
match s.as_str() {
"string" => Ok(JsonType::String),
"number" => Ok(JsonType::Number),
"object" => Ok(JsonType::Object),
"array" => Ok(JsonType::Array),
"boolean" => Ok(JsonType::Boolean),
"null" => Ok(JsonType::Null),
_ => Ok(JsonType::_UNKNOWN),
}
}
}
pub fn message_to_json_schema(pool: &DescriptorPool, message: MessageDescriptor) -> JsonSchemaEntry {
let mut schema = JsonSchemaEntry {
title: Some(message.name().to_string()),
type_: JsonType::Object, // Messages are objects
..Default::default()
};
let mut properties = HashMap::new();
message.fields().for_each(|f| match f.kind() {
prost_reflect::Kind::Message(m) => {
properties.insert(f.name().to_string(), message_to_json_schema(pool, m));
}
prost_reflect::Kind::Enum(e) => {
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()),
enum_: Some(e.values().map(|v| v.name().to_string()).collect::<Vec<_>>()),
..Default::default()
},
);
}
_ => {
// TODO: Handle repeated label
match f.field_descriptor_proto().label() {
field_descriptor_proto::Label::Repeated => {
// TODO: Handle more complex repeated types. This just handles primitives for now
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: JsonType::Array,
items: Some(Box::new(JsonSchemaEntry {
type_: map_proto_type_to_json_type(
f.field_descriptor_proto().r#type(),
),
..Default::default()
})),
..Default::default()
},
);
}
_ => {
// Regular JSON field
properties.insert(
f.name().to_string(),
JsonSchemaEntry {
type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()),
..Default::default()
},
);
}
};
}
});
schema.properties = Some(properties);
schema.required = Some(
message
.fields()
.map(|f| f.name().to_string())
.collect::<Vec<_>>(),
);
schema
}
fn map_proto_type_to_json_type(proto_type: field_descriptor_proto::Type) -> JsonType {
match proto_type {
field_descriptor_proto::Type::Double => JsonType::Number,
field_descriptor_proto::Type::Float => JsonType::Number,
field_descriptor_proto::Type::Int64 => JsonType::Number,
field_descriptor_proto::Type::Uint64 => JsonType::Number,
field_descriptor_proto::Type::Int32 => JsonType::Number,
field_descriptor_proto::Type::Fixed64 => JsonType::Number,
field_descriptor_proto::Type::Fixed32 => JsonType::Number,
field_descriptor_proto::Type::Bool => JsonType::Boolean,
field_descriptor_proto::Type::String => JsonType::String,
field_descriptor_proto::Type::Group => JsonType::_UNKNOWN,
field_descriptor_proto::Type::Message => JsonType::Object,
field_descriptor_proto::Type::Bytes => JsonType::String,
field_descriptor_proto::Type::Uint32 => JsonType::Number,
field_descriptor_proto::Type::Enum => JsonType::String,
field_descriptor_proto::Type::Sfixed32 => JsonType::Number,
field_descriptor_proto::Type::Sfixed64 => JsonType::Number,
field_descriptor_proto::Type::Sint32 => JsonType::Number,
field_descriptor_proto::Type::Sint64 => JsonType::Number,
}
}

79
src-tauri/grpc/src/lib.rs Normal file
View File

@@ -0,0 +1,79 @@
use prost_reflect::DynamicMessage;
use serde::{Deserialize, Serialize};
use serde_json::Deserializer;
use tonic::IntoRequest;
use tonic::transport::Uri;
use crate::codec::DynamicCodec;
use crate::proto::{fill_pool, method_desc_to_path};
mod codec;
mod json_schema;
mod proto;
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct ServiceDefinition {
pub name: String,
pub methods: Vec<MethodDefinition>,
}
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct MethodDefinition {
pub name: String,
pub schema: String,
}
pub async fn call(uri: &Uri, service: &str, method: &str, message_json: &str) -> String {
let (pool, conn) = fill_pool(uri).await;
let service = pool.get_service_by_name(service).unwrap();
let method = &service.methods().find(|m| m.name() == method).unwrap();
let input_message = method.input();
let mut deserializer = Deserializer::from_str(message_json);
let req_message = DynamicMessage::deserialize(input_message, &mut deserializer).unwrap();
deserializer.end().unwrap();
let mut client = tonic::client::Grpc::new(conn);
println!(
"\n---------- SENDING -----------------\n{}",
serde_json::to_string_pretty(&req_message).expect("json")
);
let req = req_message.into_request();
let path = method_desc_to_path(method);
let codec = DynamicCodec::new(method.clone());
client.ready().await.unwrap();
let resp = client.unary(req, path, codec).await.unwrap();
let response_json = serde_json::to_string_pretty(&resp.into_inner()).expect("json to string");
println!("\n---------- RECEIVING ---------------\n{}", response_json,);
response_json
}
pub async fn callable(uri: &Uri) -> Vec<ServiceDefinition> {
let (pool, _) = fill_pool(uri).await;
pool.services()
.map(|s| {
let mut def = ServiceDefinition {
name: s.full_name().to_string(),
..Default::default()
};
for method in s.methods() {
let input_message = method.input();
def.methods.push(MethodDefinition {
name: method.name().to_string(),
schema: serde_json::to_string_pretty(&json_schema::message_to_json_schema(
&pool,
input_message,
))
.unwrap(),
})
}
def
})
.collect::<Vec<_>>()
}

137
src-tauri/grpc/src/proto.rs Normal file
View File

@@ -0,0 +1,137 @@
use std::ops::Deref;
use std::str::FromStr;
use anyhow::anyhow;
use prost::Message;
use prost_reflect::{DescriptorPool, MethodDescriptor};
use prost_types::FileDescriptorProto;
use tokio_stream::StreamExt;
use tonic::codegen::http::uri::PathAndQuery;
use tonic::Request;
use tonic::transport::{Channel, Uri};
use tonic_reflection::pb::server_reflection_client::ServerReflectionClient;
use tonic_reflection::pb::server_reflection_request::MessageRequest;
use tonic_reflection::pb::server_reflection_response::MessageResponse;
use tonic_reflection::pb::ServerReflectionRequest;
pub async fn fill_pool(uri: &Uri) -> (DescriptorPool, Channel) {
let mut pool = DescriptorPool::new();
let conn = tonic::transport::Endpoint::new(uri.clone())
.unwrap()
.connect()
.await
.unwrap();
let mut client = ServerReflectionClient::new(conn.clone());
let services = list_services(&mut client).await;
for service in services {
if service == "grpc.reflection.v1alpha.ServerReflection" {
continue;
}
file_descriptor_set_from_service_name(&service, &mut pool, &mut client).await;
}
(pool, conn)
}
async fn list_services(reflect_client: &mut ServerReflectionClient<Channel>) -> Vec<String> {
let response =
send_reflection_request(reflect_client, MessageRequest::ListServices("".into())).await;
let list_services_response = match response {
MessageResponse::ListServicesResponse(resp) => resp,
_ => panic!("Expected a ListServicesResponse variant"),
};
list_services_response
.service
.iter()
.map(|s| s.name.clone())
.collect::<Vec<_>>()
}
async fn file_descriptor_set_from_service_name(
service_name: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Channel>,
) {
let response = send_reflection_request(
client,
MessageRequest::FileContainingSymbol(service_name.into()),
)
.await;
let file_descriptor_response = match response {
MessageResponse::FileDescriptorResponse(resp) => resp,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
for fd in file_descriptor_response.file_descriptor_proto {
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
// Add deps first or else we'll get an error
for dep_name in fdp.clone().dependency {
file_descriptor_set_by_filename(&dep_name, pool, client).await;
}
pool.add_file_descriptor_proto(fdp)
.expect("add file descriptor proto");
}
}
async fn file_descriptor_set_by_filename(
filename: &str,
pool: &mut DescriptorPool,
client: &mut ServerReflectionClient<Channel>,
) {
// We already fetched this file
if let Some(_) = pool.get_file_by_name(filename) {
return;
}
let response =
send_reflection_request(client, MessageRequest::FileByFilename(filename.into())).await;
let file_descriptor_response = match response {
MessageResponse::FileDescriptorResponse(resp) => resp,
_ => panic!("Expected a FileDescriptorResponse variant"),
};
for fd in file_descriptor_response.file_descriptor_proto {
let fdp = FileDescriptorProto::decode(fd.deref()).unwrap();
pool.add_file_descriptor_proto(fdp)
.expect("add file descriptor proto");
}
}
async fn send_reflection_request(
client: &mut ServerReflectionClient<Channel>,
message: MessageRequest,
) -> MessageResponse {
let reflection_request = ServerReflectionRequest {
host: "".into(), // Doesn't matter
message_request: Some(message),
};
let request = Request::new(tokio_stream::once(reflection_request));
client
.server_reflection_info(request)
.await
.expect("server reflection failed")
.into_inner()
.next()
.await
.expect("steamed response")
.expect("successful response")
.message_response
.expect("some MessageResponse")
}
pub fn method_desc_to_path(md: &MethodDescriptor) -> PathAndQuery {
let full_name = md.full_name();
let (namespace, method_name) = full_name
.rsplit_once('.')
.ok_or_else(|| anyhow!("invalid method path"))
.expect("invalid method path");
PathAndQuery::from_str(&format!("/{}/{}", namespace, method_name)).expect("invalid method path")
}