use prost_reflect::{DescriptorPool, FieldDescriptor, MessageDescriptor}; use std::collections::{HashMap, HashSet, VecDeque}; pub fn message_to_json_schema(_: &DescriptorPool, root_msg: MessageDescriptor) -> JsonSchemaEntry { JsonSchemaGenerator::generate_json_schema(root_msg) } struct JsonSchemaGenerator { msg_mapping: HashMap, } impl JsonSchemaGenerator { pub fn new() -> Self { JsonSchemaGenerator { msg_mapping: HashMap::new(), } } pub fn generate_json_schema(msg: MessageDescriptor) -> JsonSchemaEntry { let generator = JsonSchemaGenerator::new(); generator.scan_root(msg) } fn add_message(&mut self, msg: &MessageDescriptor) { let name = msg.full_name().to_string(); if self.msg_mapping.contains_key(&name) { return; } self.msg_mapping.insert(name.clone(), JsonSchemaEntry::object()); } pub fn scan_root(mut self, root_msg: MessageDescriptor) -> JsonSchemaEntry { self.init_structure(root_msg.clone()); self.fill_properties(root_msg.clone()); let mut root = self.msg_mapping.remove(root_msg.full_name()).unwrap(); if self.msg_mapping.len() > 0 { root.defs = Some(self.msg_mapping); } root } fn fill_properties(&mut self, root_msg: MessageDescriptor) { let root_name = root_msg.full_name().to_string(); let mut visited = HashSet::new(); let mut msg_queue = VecDeque::new(); msg_queue.push_back(root_msg); while !msg_queue.is_empty() { let msg = msg_queue.pop_front().unwrap(); let msg_name = msg.full_name(); if visited.contains(msg_name) { continue; } visited.insert(msg_name.to_string()); let entry = self.msg_mapping.get_mut(msg_name).unwrap(); for field in msg.fields() { let field_name = field.name().to_string(); if matches!(field.cardinality(), prost_reflect::Cardinality::Required) { entry.add_required(field_name.clone()); } if let Some(oneof) = field.containing_oneof() { for oneof_field in oneof.fields() { if let Some(fm) = is_message_field(&oneof_field) { msg_queue.push_back(fm); } entry.add_property( oneof_field.name().to_string(), field_to_type_or_ref(&root_name, oneof_field), ); } continue; } let (field_type, nest_msg) = { if let Some(fm) = is_message_field(&field) { if field.is_list() { // repeated message type ( JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), Some(fm), ) } else if field.is_map() { let value_field = fm.get_field_by_name("value").unwrap(); if let Some(fm) = is_message_field(&value_field) { ( JsonSchemaEntry::map(field_to_type_or_ref( &root_name, value_field, )), Some(fm), ) } else { ( JsonSchemaEntry::map(field_to_type_or_ref( &root_name, value_field, )), None, ) } } else { (field_to_type_or_ref(&root_name, field), Some(fm)) } } else { if field.is_list() { // repeated scalar type (JsonSchemaEntry::array(field_to_type_or_ref(&root_name, field)), None) } else { (field_to_type_or_ref(&root_name, field), None) } } }; if let Some(fm) = nest_msg { msg_queue.push_back(fm); } entry.add_property(field_name, field_type); } } } fn init_structure(&mut self, root_msg: MessageDescriptor) { let mut visited = HashSet::new(); let mut msg_queue = VecDeque::new(); msg_queue.push_back(root_msg.clone()); // level traversal, to make sure all message type is defined before used while !msg_queue.is_empty() { let msg = msg_queue.pop_front().unwrap(); let name = msg.full_name(); if visited.contains(name) { continue; } visited.insert(name.to_string()); self.add_message(&msg); for child in msg.child_messages() { if child.is_map_entry() { // for field with map type, there will be a child message type *Entry generated // just skip it continue; } self.add_message(&child); msg_queue.push_back(child); } for field in msg.fields() { if let Some(oneof) = field.containing_oneof() { for oneof_field in oneof.fields() { if let Some(fm) = is_message_field(&oneof_field) { self.add_message(&fm); msg_queue.push_back(fm); } } continue; } if field.is_map() { // key is always scalar type, so no need to process // value can be any type, so need to unpack value type let map_field_msg = is_message_field(&field).unwrap(); let map_value_field = map_field_msg.get_field_by_name("value").unwrap(); if let Some(value_fm) = is_message_field(&map_value_field) { self.add_message(&value_fm); msg_queue.push_back(value_fm); } continue; } if let Some(fm) = is_message_field(&field) { self.add_message(&fm); msg_queue.push_back(fm); } } } } } fn field_to_type_or_ref(root_name: &str, field: FieldDescriptor) -> JsonSchemaEntry { match field.kind() { prost_reflect::Kind::Bool => JsonSchemaEntry::boolean(), prost_reflect::Kind::Double => JsonSchemaEntry::number("double"), prost_reflect::Kind::Float => JsonSchemaEntry::number("float"), prost_reflect::Kind::Int32 => JsonSchemaEntry::number("int32"), prost_reflect::Kind::Int64 => JsonSchemaEntry::string_with_format("int64"), prost_reflect::Kind::Uint32 => JsonSchemaEntry::number("int64"), prost_reflect::Kind::Uint64 => JsonSchemaEntry::string_with_format("uint64"), prost_reflect::Kind::Sint32 => JsonSchemaEntry::number("sint32"), prost_reflect::Kind::Sint64 => JsonSchemaEntry::string_with_format("sint64"), prost_reflect::Kind::Fixed32 => JsonSchemaEntry::number("int64"), prost_reflect::Kind::Fixed64 => JsonSchemaEntry::string_with_format("fixed64"), prost_reflect::Kind::Sfixed32 => JsonSchemaEntry::number("sfixed32"), prost_reflect::Kind::Sfixed64 => JsonSchemaEntry::string_with_format("sfixed64"), prost_reflect::Kind::String => JsonSchemaEntry::string(), prost_reflect::Kind::Bytes => JsonSchemaEntry::string_with_format("byte"), prost_reflect::Kind::Enum(enums) => { let values = enums.values().map(|v| v.name().to_string()).collect::>(); JsonSchemaEntry::enums(values) } prost_reflect::Kind::Message(fm) => { let field_type_full_name = fm.full_name(); match field_type_full_name { // [Protocol Buffers Well-Known Types]: https://protobuf.dev/reference/protobuf/google.protobuf/ "google.protobuf.FieldMask" => JsonSchemaEntry::string(), "google.protobuf.Timestamp" => JsonSchemaEntry::string_with_format("date-time"), "google.protobuf.Duration" => JsonSchemaEntry::string(), "google.protobuf.StringValue" => JsonSchemaEntry::string(), "google.protobuf.BytesValue" => JsonSchemaEntry::string_with_format("byte"), "google.protobuf.Int32Value" => JsonSchemaEntry::number("int32"), "google.protobuf.UInt32Value" => JsonSchemaEntry::string_with_format("int64"), "google.protobuf.Int64Value" => JsonSchemaEntry::string_with_format("int64"), "google.protobuf.UInt64Value" => JsonSchemaEntry::string_with_format("uint64"), "google.protobuf.FloatValue" => JsonSchemaEntry::number("float"), "google.protobuf.DoubleValue" => JsonSchemaEntry::number("double"), "google.protobuf.BoolValue" => JsonSchemaEntry::boolean(), "google.protobuf.Empty" => JsonSchemaEntry::default(), "google.protobuf.Struct" => JsonSchemaEntry::object(), "google.protobuf.ListValue" => JsonSchemaEntry::array(JsonSchemaEntry::default()), "google.protobuf.NullValue" => JsonSchemaEntry::null(), name @ _ if name == root_name => JsonSchemaEntry::root_reference(), _ => JsonSchemaEntry::reference(fm.full_name()), } } } } fn is_message_field(field: &FieldDescriptor) -> Option { match field.kind() { prost_reflect::Kind::Message(m) => Some(m), _ => None, } } #[derive(Default, serde::Serialize)] #[serde(default, rename_all = "camelCase")] pub struct JsonSchemaEntry { #[serde(skip_serializing_if = "Option::is_none")] title: Option, #[serde(rename = "type", skip_serializing_if = "Option::is_none")] type_: Option, #[serde(skip_serializing_if = "Option::is_none")] format: Option, #[serde(skip_serializing_if = "Option::is_none")] description: Option, #[serde(skip_serializing_if = "Option::is_none")] properties: Option>, #[serde(rename = "enum", skip_serializing_if = "Option::is_none")] enum_: Option>, // for map type #[serde(skip_serializing_if = "Option::is_none")] additional_properties: Option>, // Set all properties to required #[serde(skip_serializing_if = "Option::is_none")] required: Option>, #[serde(skip_serializing_if = "Option::is_none")] items: Option>, #[serde(skip_serializing_if = "Option::is_none", rename = "$defs")] defs: Option>, #[serde(skip_serializing_if = "Option::is_none", rename = "$ref")] ref_: Option, } impl JsonSchemaEntry { pub fn add_property(&mut self, name: String, entry: JsonSchemaEntry) { if self.properties.is_none() { self.properties = Some(HashMap::new()); } self.properties.as_mut().unwrap().insert(name, entry); } pub fn add_required(&mut self, name: String) { if self.required.is_none() { self.required = Some(Vec::new()); } self.required.as_mut().unwrap().push(name); } } impl JsonSchemaEntry { pub fn object() -> Self { JsonSchemaEntry { type_: Some(JsonType::Object), ..Default::default() } } pub fn boolean() -> Self { JsonSchemaEntry { type_: Some(JsonType::Boolean), ..Default::default() } } pub fn number>(format: S) -> Self { JsonSchemaEntry { type_: Some(JsonType::Number), format: Some(format.into()), ..Default::default() } } pub fn string() -> Self { JsonSchemaEntry { type_: Some(JsonType::String), ..Default::default() } } pub fn string_with_format>(format: S) -> Self { JsonSchemaEntry { type_: Some(JsonType::String), format: Some(format.into()), ..Default::default() } } pub fn reference>(ref_: S) -> Self { JsonSchemaEntry { ref_: Some(format!("#/$defs/{}", ref_.as_ref())), ..Default::default() } } pub fn root_reference() -> Self{ JsonSchemaEntry { ref_: Some("#".to_string()), ..Default::default() } } pub fn array(item: JsonSchemaEntry) -> Self { JsonSchemaEntry { type_: Some(JsonType::Array), items: Some(Box::new(item)), ..Default::default() } } pub fn enums(enums: Vec) -> Self { JsonSchemaEntry { type_: Some(JsonType::String), enum_: Some(enums), ..Default::default() } } pub fn map(value_type: JsonSchemaEntry) -> Self { JsonSchemaEntry { type_: Some(JsonType::Object), additional_properties: Some(Box::new(value_type)), ..Default::default() } } pub fn null() -> Self { JsonSchemaEntry { type_: Some(JsonType::Null), ..Default::default() } } } enum JsonType { String, Number, Object, Array, Boolean, Null, _UNKNOWN, } impl Default for JsonType { fn default() -> Self { JsonType::_UNKNOWN } } impl serde::Serialize for JsonType { fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { match self { JsonType::String => serializer.serialize_str("string"), JsonType::Number => serializer.serialize_str("number"), JsonType::Object => serializer.serialize_str("object"), JsonType::Array => serializer.serialize_str("array"), JsonType::Boolean => serializer.serialize_str("boolean"), JsonType::Null => serializer.serialize_str("null"), JsonType::_UNKNOWN => serializer.serialize_str("unknown"), } } }