From 42e70b941d7f918b5d0a4dc8aff8de583045ec65 Mon Sep 17 00:00:00 2001 From: Hao Xiang Date: Sat, 17 May 2025 03:53:53 +0800 Subject: [PATCH] fix proto to json-schema (#194) --- src-tauri/yaak-grpc/src/json_schema.rs | 475 ++++++++++++++++++------- 1 file changed, 351 insertions(+), 124 deletions(-) diff --git a/src-tauri/yaak-grpc/src/json_schema.rs b/src-tauri/yaak-grpc/src/json_schema.rs index b915c6c0..685b8e11 100644 --- a/src-tauri/yaak-grpc/src/json_schema.rs +++ b/src-tauri/yaak-grpc/src/json_schema.rs @@ -1,16 +1,256 @@ -use prost_reflect::{DescriptorPool, MessageDescriptor}; -use prost_types::field_descriptor_proto; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use prost_reflect::{DescriptorPool, FieldDescriptor, MessageDescriptor}; +use std::collections::{HashMap, HashSet, VecDeque}; -#[derive(Default, Serialize, Deserialize)] +pub fn message_to_json_schema(_: &DescriptorPool, root_msg: MessageDescriptor) -> JsonSchemaEntry { + JsonSchemaGenerator::generate_json_schema(root_msg) +} + +struct JsonSchemaGenerator { + msg_mapping: HashMap, +} + +impl JsonSchemaGenerator { + pub fn new() -> Self { + JsonSchemaGenerator { + msg_mapping: HashMap::new(), + } + } + + pub fn generate_json_schema(msg: MessageDescriptor) -> JsonSchemaEntry { + let generator = JsonSchemaGenerator::new(); + generator.scan_root(msg) + } + + fn add_message(&mut self, msg: &MessageDescriptor) { + let name = msg.full_name().to_string(); + if self.msg_mapping.contains_key(&name) { + return; + } + self.msg_mapping.insert(name.clone(), JsonSchemaEntry::object()); + } + + pub fn scan_root(mut self, root_msg: MessageDescriptor) -> JsonSchemaEntry { + self.init_structure(root_msg.clone()); + self.fill_properties(root_msg.clone()); + + let mut root = self.msg_mapping.remove(root_msg.full_name()).unwrap(); + + if self.msg_mapping.len() > 0 { + root.defs = Some(self.msg_mapping); + } + root + } + + fn fill_properties(&mut self, root_msg: MessageDescriptor) { + let root_name = root_msg.full_name().to_string(); + + let mut visited = HashSet::new(); + let mut msg_queue = VecDeque::new(); + msg_queue.push_back(root_msg); + + while !msg_queue.is_empty() { + let msg = msg_queue.pop_front().unwrap(); + let msg_name = msg.full_name(); + if visited.contains(msg_name) { + continue; + } + + visited.insert(msg_name.to_string()); + + let entry = self.msg_mapping.get_mut(msg_name).unwrap(); + + for field in msg.fields() { + let field_name = field.name().to_string(); + + if matches!(field.cardinality(), prost_reflect::Cardinality::Required) { + entry.add_required(field_name.clone()); + } + + if let Some(oneof) = field.containing_oneof() { + for oneof_field in oneof.fields() { + if let Some(fm) = is_message_field(&oneof_field) { + msg_queue.push_back(fm); + } + entry.add_property( + oneof_field.name().to_string(), + field_to_type_or_ref(&root_name, oneof_field), + ); + } + continue; + } + + let (field_type, nest_msg) = { + if let Some(fm) = is_message_field(&field) { + if field.is_list() { + // repeated message type + ( + JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), + Some(fm), + ) + } else if field.is_map() { + let value_field = fm.get_field_by_name("value").unwrap(); + + if let Some(fm) = is_message_field(&value_field) { + ( + JsonSchemaEntry::map(field_to_type_or_ref( + &root_name, + value_field, + )), + Some(fm), + ) + } else { + ( + JsonSchemaEntry::map(field_to_type_or_ref( + &root_name, + value_field, + )), + None, + ) + } + } else { + (field_to_type_or_ref(&root_name, field), Some(fm)) + } + } else { + if field.is_list() { + // repeated scalar type + (JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), None) + } else { + (field_to_type_or_ref(&root_name, field), None) + } + } + }; + + if let Some(fm) = nest_msg { + msg_queue.push_back(fm); + } + + entry.add_property(field_name, field_type); + } + } + } + + fn init_structure(&mut self, root_msg: MessageDescriptor) { + let mut visited = HashSet::new(); + let mut msg_queue = VecDeque::new(); + msg_queue.push_back(root_msg.clone()); + + // level traversal, to make sure all message type is defined before used + while !msg_queue.is_empty() { + let msg = msg_queue.pop_front().unwrap(); + let name = msg.full_name(); + if visited.contains(name) { + continue; + } + visited.insert(name.to_string()); + self.add_message(&msg); + + for child in msg.child_messages() { + if child.is_map_entry() { + // for field with map type, there will be a child message type *Entry generated + // just skip it + continue; + } + + self.add_message(&child); + msg_queue.push_back(child); + } + + for field in msg.fields() { + if let Some(oneof) = field.containing_oneof() { + for oneof_field in oneof.fields() { + if let Some(fm) = is_message_field(&oneof_field) { + self.add_message(&fm); + msg_queue.push_back(fm); + } + } + continue; + } + if field.is_map() { + // key is always scalar type, so no need to process + // value can be any type, so need to unpack value type + let map_field_msg = is_message_field(&field).unwrap(); + let map_value_field = map_field_msg.get_field_by_name("value").unwrap(); + if let Some(value_fm) = is_message_field(&map_value_field) { + self.add_message(&value_fm); + msg_queue.push_back(value_fm); + } + continue; + } + if let Some(fm) = is_message_field(&field) { + self.add_message(&fm); + msg_queue.push_back(fm); + } + } + } + } +} + +fn field_to_type_or_ref(root_name: &str, field: FieldDescriptor) -> JsonSchemaEntry { + match field.kind() { + prost_reflect::Kind::Bool => JsonSchemaEntry::boolean(), + prost_reflect::Kind::Double => JsonSchemaEntry::number("double"), + prost_reflect::Kind::Float => JsonSchemaEntry::number("float"), + prost_reflect::Kind::Int32 => JsonSchemaEntry::number("int32"), + prost_reflect::Kind::Int64 => JsonSchemaEntry::string_with_format("int64"), + prost_reflect::Kind::Uint32 => JsonSchemaEntry::number("int64"), + prost_reflect::Kind::Uint64 => JsonSchemaEntry::string_with_format("uint64"), + prost_reflect::Kind::Sint32 => JsonSchemaEntry::number("sint32"), + prost_reflect::Kind::Sint64 => JsonSchemaEntry::string_with_format("sint64"), + prost_reflect::Kind::Fixed32 => JsonSchemaEntry::number("int64"), + prost_reflect::Kind::Fixed64 => JsonSchemaEntry::string_with_format("fixed64"), + prost_reflect::Kind::Sfixed32 => JsonSchemaEntry::number("sfixed32"), + prost_reflect::Kind::Sfixed64 => JsonSchemaEntry::string_with_format("sfixed64"), + prost_reflect::Kind::String => JsonSchemaEntry::string(), + prost_reflect::Kind::Bytes => JsonSchemaEntry::string_with_format("byte"), + prost_reflect::Kind::Enum(enums) => { + let values = enums.values().map(|v| v.name().to_string()).collect::>(); + JsonSchemaEntry::enums(values) + } + prost_reflect::Kind::Message(fm) => { + let field_type_full_name = fm.full_name(); + match field_type_full_name { + // [Protocol Buffers Well-Known Types]: https://protobuf.dev/reference/protobuf/google.protobuf/ + "google.protobuf.FieldMask" => JsonSchemaEntry::string(), + "google.protobuf.Timestamp" => JsonSchemaEntry::string_with_format("date-time"), + "google.protobuf.Duration" => JsonSchemaEntry::string(), + "google.protobuf.StringValue" => JsonSchemaEntry::string(), + "google.protobuf.BytesValue" => JsonSchemaEntry::string_with_format("byte"), + "google.protobuf.Int32Value" => JsonSchemaEntry::number("int32"), + "google.protobuf.UInt32Value" => JsonSchemaEntry::string_with_format("int64"), + "google.protobuf.Int64Value" => JsonSchemaEntry::string_with_format("int64"), + "google.protobuf.UInt64Value" => JsonSchemaEntry::string_with_format("uint64"), + "google.protobuf.FloatValue" => JsonSchemaEntry::number("float"), + "google.protobuf.DoubleValue" => JsonSchemaEntry::number("double"), + "google.protobuf.BoolValue" => JsonSchemaEntry::boolean(), + "google.protobuf.Empty" => JsonSchemaEntry::default(), + "google.protobuf.Struct" => JsonSchemaEntry::object(), + "google.protobuf.ListValue" => JsonSchemaEntry::array(JsonSchemaEntry::default()), + "google.protobuf.NullValue" => JsonSchemaEntry::null(), + name @ _ if name == root_name => JsonSchemaEntry::root_reference(), + _ => JsonSchemaEntry::reference(fm.full_name()), + } + } + } +} + +fn is_message_field(field: &FieldDescriptor) -> Option { + match field.kind() { + prost_reflect::Kind::Message(m) => Some(m), + _ => None, + } +} + +#[derive(Default, serde::Serialize)] #[serde(default, rename_all = "camelCase")] pub struct JsonSchemaEntry { #[serde(skip_serializing_if = "Option::is_none")] title: Option, - #[serde(rename = "type")] - type_: JsonType, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + type_: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + format: Option, #[serde(skip_serializing_if = "Option::is_none")] description: Option, @@ -21,15 +261,115 @@ pub struct JsonSchemaEntry { #[serde(rename = "enum", skip_serializing_if = "Option::is_none")] enum_: Option>, - /// Don't allow any other properties in the object - additional_properties: bool, + // for map type + #[serde(skip_serializing_if = "Option::is_none")] + additional_properties: Option>, - /// Set all properties to required + // Set all properties to required #[serde(skip_serializing_if = "Option::is_none")] required: Option>, #[serde(skip_serializing_if = "Option::is_none")] items: Option>, + + #[serde(skip_serializing_if = "Option::is_none", rename = "$defs")] + defs: Option>, + + #[serde(skip_serializing_if = "Option::is_none", rename = "$ref")] + ref_: Option, +} + +impl JsonSchemaEntry { + pub fn add_property(&mut self, name: String, entry: JsonSchemaEntry) { + if self.properties.is_none() { + self.properties = Some(HashMap::new()); + } + self.properties.as_mut().unwrap().insert(name, entry); + } + + pub fn add_required(&mut self, name: String) { + if self.required.is_none() { + self.required = Some(Vec::new()); + } + self.required.as_mut().unwrap().push(name); + } +} + +impl JsonSchemaEntry { + pub fn object() -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Object), + ..Default::default() + } + } + pub fn boolean() -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Boolean), + ..Default::default() + } + } + pub fn number>(format: S) -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Number), + format: Some(format.into()), + ..Default::default() + } + } + pub fn string() -> Self { + JsonSchemaEntry { + type_: Some(JsonType::String), + ..Default::default() + } + } + + pub fn string_with_format>(format: S) -> Self { + JsonSchemaEntry { + type_: Some(JsonType::String), + format: Some(format.into()), + ..Default::default() + } + } + pub fn reference>(ref_: S) -> Self { + JsonSchemaEntry { + ref_: Some(format!("#/$defs/{}", ref_.as_ref())), + ..Default::default() + } + } + pub fn root_reference() -> Self{ + JsonSchemaEntry { + ref_: Some("#".to_string()), + ..Default::default() + } + } + pub fn array(item: JsonSchemaEntry) -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Array), + items: Some(Box::new(item)), + ..Default::default() + } + } + pub fn enums(enums: Vec) -> Self { + JsonSchemaEntry { + type_: Some(JsonType::String), + enum_: Some(enums), + ..Default::default() + } + } + + pub fn map(value_type: JsonSchemaEntry) -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Object), + additional_properties: Some(Box::new(value_type)), + ..Default::default() + } + } + + pub fn null() -> Self { + JsonSchemaEntry { + type_: Some(JsonType::Null), + ..Default::default() + } + } } enum JsonType { @@ -49,7 +389,7 @@ impl Default for JsonType { } impl serde::Serialize for JsonType { - fn serialize(&self, serializer: S) -> Result + fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { @@ -64,116 +404,3 @@ impl serde::Serialize for JsonType { } } } - -impl<'de> serde::Deserialize<'de> for JsonType { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - match s.as_str() { - "string" => Ok(JsonType::String), - "number" => Ok(JsonType::Number), - "object" => Ok(JsonType::Object), - "array" => Ok(JsonType::Array), - "boolean" => Ok(JsonType::Boolean), - "null" => Ok(JsonType::Null), - _ => Ok(JsonType::_UNKNOWN), - } - } -} - -pub fn message_to_json_schema( - pool: &DescriptorPool, - message: MessageDescriptor, -) -> JsonSchemaEntry { - let mut schema = JsonSchemaEntry { - title: Some(message.name().to_string()), - type_: JsonType::Object, // Messages are objects - ..Default::default() - }; - - let mut properties = HashMap::new(); - message.fields().for_each(|f| match f.kind() { - prost_reflect::Kind::Message(m) => { - properties.insert(f.name().to_string(), message_to_json_schema(pool, m)); - } - prost_reflect::Kind::Enum(e) => { - properties.insert( - f.name().to_string(), - JsonSchemaEntry { - type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()), - enum_: Some(e.values().map(|v| v.name().to_string()).collect::>()), - ..Default::default() - }, - ); - } - _ => { - // TODO: Handle repeated label - match f.field_descriptor_proto().label() { - field_descriptor_proto::Label::Repeated => { - // TODO: Handle more complex repeated types. This just handles primitives for now - properties.insert( - f.name().to_string(), - JsonSchemaEntry { - type_: JsonType::Array, - items: Some(Box::new(JsonSchemaEntry { - type_: map_proto_type_to_json_type( - f.field_descriptor_proto().r#type(), - ), - ..Default::default() - })), - ..Default::default() - }, - ); - } - _ => { - // Regular JSON field - properties.insert( - f.name().to_string(), - JsonSchemaEntry { - type_: map_proto_type_to_json_type(f.field_descriptor_proto().r#type()), - ..Default::default() - }, - ); - } - }; - } - }); - - schema.properties = Some(properties); - - // All proto 3 fields are optional, so maybe we could - // make this a setting? - // schema.required = Some( - // message - // .fields() - // .map(|f| f.name().to_string()) - // .collect::>(), - // ); - - schema -} - -fn map_proto_type_to_json_type(proto_type: field_descriptor_proto::Type) -> JsonType { - match proto_type { - field_descriptor_proto::Type::Double => JsonType::Number, - field_descriptor_proto::Type::Float => JsonType::Number, - field_descriptor_proto::Type::Int64 => JsonType::Number, - field_descriptor_proto::Type::Uint64 => JsonType::Number, - field_descriptor_proto::Type::Int32 => JsonType::Number, - field_descriptor_proto::Type::Fixed64 => JsonType::Number, - field_descriptor_proto::Type::Fixed32 => JsonType::Number, - field_descriptor_proto::Type::Bool => JsonType::Boolean, - field_descriptor_proto::Type::String => JsonType::String, - field_descriptor_proto::Type::Group => JsonType::_UNKNOWN, - field_descriptor_proto::Type::Message => JsonType::Object, - field_descriptor_proto::Type::Bytes => JsonType::String, - field_descriptor_proto::Type::Uint32 => JsonType::Number, - field_descriptor_proto::Type::Enum => JsonType::String, - field_descriptor_proto::Type::Sfixed32 => JsonType::Number, - field_descriptor_proto::Type::Sfixed64 => JsonType::Number, - field_descriptor_proto::Type::Sint32 => JsonType::Number, - field_descriptor_proto::Type::Sint64 => JsonType::Number, - } -}