mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-28 10:29:30 +02:00
clippy: adhere to pedantic clippy, uniform test error handling
This commit is contained in:
@@ -62,7 +62,7 @@ pub async fn set_api_key(
|
||||
let api_key = User::set_api_key(&user.id, &state.db).await?;
|
||||
|
||||
// Clear the cache so new requests have access to the user with api key
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
// Render the API key section block
|
||||
Ok(TemplateResponse::new_partial(
|
||||
@@ -106,7 +106,7 @@ pub async fn update_timezone(
|
||||
User::update_timezone(&user.id, &form.timezone, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
let timezones = TZ_VARIANTS
|
||||
.iter()
|
||||
@@ -141,7 +141,7 @@ pub async fn update_theme(
|
||||
User::update_theme(&user.id, &form.theme, &state.db).await?;
|
||||
|
||||
// Clear the cache
|
||||
auth.cache_clear_user(user.id.to_string());
|
||||
auth.cache_clear_user(user.id.clone());
|
||||
|
||||
let theme_options = vec![
|
||||
Theme::Light.as_str().to_string(),
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::types::ListModelResponse;
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
@@ -37,18 +39,14 @@ pub struct AdminPanelData {
|
||||
current_section: AdminSection,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AdminSection {
|
||||
#[default]
|
||||
Overview,
|
||||
Models,
|
||||
}
|
||||
|
||||
impl Default for AdminSection {
|
||||
fn default() -> Self {
|
||||
Self::Overview
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AdminPanelQuery {
|
||||
@@ -107,10 +105,7 @@ fn checkbox_to_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
match String::deserialize(deserializer) {
|
||||
Ok(string) => Ok(string == "on"),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
String::deserialize(deserializer).map(|s| s == "on")
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -219,8 +214,8 @@ pub async fn update_model_settings(
|
||||
if reembedding_needed {
|
||||
info!("Embedding dimensions changed. Spawning background re-embedding task...");
|
||||
|
||||
let db_for_task = state.db.clone();
|
||||
let openai_for_task = state.openai_client.clone();
|
||||
let db_for_task = Arc::clone(&state.db);
|
||||
let openai_for_task = Arc::clone(&state.openai_client);
|
||||
let new_model_for_task = new_settings.embedding_model.clone();
|
||||
let new_dims_for_task = new_settings.embedding_dimensions;
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct SignupParams {
|
||||
pub struct Params {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
pub timezone: String,
|
||||
@@ -39,7 +39,7 @@ pub async fn show_signup_form(
|
||||
pub async fn process_signup_and_show_verification(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSessionType,
|
||||
Form(form): Form<SignupParams>,
|
||||
Form(form): Form<Params>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match User::create_new(
|
||||
form.email,
|
||||
|
||||
@@ -49,6 +49,8 @@ pub struct ChatPageData {
|
||||
conversation: Option<Conversation>,
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn show_initialized_chat(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
@@ -57,14 +59,14 @@ pub async fn show_initialized_chat(
|
||||
let conversation = Conversation::new(user.id.clone(), "Test".to_owned());
|
||||
|
||||
let user_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
conversation.id.clone(),
|
||||
MessageRole::User,
|
||||
form.user_query,
|
||||
None,
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
conversation.id.to_string(),
|
||||
conversation.id.clone(),
|
||||
MessageRole::AI,
|
||||
form.llm_response,
|
||||
Some(form.references),
|
||||
@@ -86,10 +88,9 @@ pub async fn show_initialized_chat(
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -130,12 +131,19 @@ pub async fn show_existing_chat(
|
||||
))
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_user_message(
|
||||
Path(conversation_id): Path<String>,
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let conversation: Conversation = state
|
||||
.db
|
||||
.get_item(&conversation_id)
|
||||
@@ -150,33 +158,34 @@ pub async fn new_user_message(
|
||||
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
}
|
||||
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/streaming_response.html",
|
||||
SSEResponseInitData { user_message },
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Panics if the HX-Push header value cannot be parsed.
|
||||
pub async fn new_chat_user_message(
|
||||
State(state): State<HtmlState>,
|
||||
auth: AuthSession<User, String, SessionSurrealPool<Any>, Surreal<Any>>,
|
||||
Form(form): Form<NewMessageForm>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user = match auth.current_user {
|
||||
Some(user) => user,
|
||||
None => return Ok(Redirect::to("/").into_response()),
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
|
||||
let Some(user) = auth.current_user else {
|
||||
return Ok(Redirect::to("/").into_response());
|
||||
};
|
||||
|
||||
let conversation = Conversation::new(user.id.clone(), "New chat".to_string());
|
||||
@@ -191,11 +200,6 @@ pub async fn new_chat_user_message(
|
||||
state.db.store_item(user_message.clone()).await?;
|
||||
state.invalidate_conversation_archive_cache(&user.id).await;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SSEResponseInitData {
|
||||
user_message: Message,
|
||||
conversation: Conversation,
|
||||
}
|
||||
let mut response = TemplateResponse::new_template(
|
||||
"chat/new_chat_first_response.html",
|
||||
SSEResponseInitData {
|
||||
@@ -205,10 +209,9 @@ pub async fn new_chat_user_message(
|
||||
)
|
||||
.into_response();
|
||||
|
||||
response.headers_mut().insert(
|
||||
"HX-Push",
|
||||
HeaderValue::from_str(&format!("/chat/{}", conversation.id)).unwrap(),
|
||||
);
|
||||
if let Ok(header_value) = HeaderValue::from_str(&format!("/chat/{}", conversation.id)) {
|
||||
response.headers_mut().insert("HX-Push", header_value);
|
||||
}
|
||||
|
||||
Ok(response.into_response())
|
||||
}
|
||||
|
||||
@@ -53,26 +53,22 @@ fn sse_with_keep_alive(stream: EventStream) -> SseResponse {
|
||||
)
|
||||
}
|
||||
|
||||
// Error handling function
|
||||
fn create_error_stream(message: impl Into<String>) -> EventStream {
|
||||
let message = message.into();
|
||||
stream::once(async move { Ok(Event::default().event("error").data(message)) }).boxed()
|
||||
}
|
||||
|
||||
// Helper function to get message and user
|
||||
async fn get_message_and_user(
|
||||
db: &SurrealDbClient,
|
||||
current_user: Option<User>,
|
||||
message_id: &str,
|
||||
) -> Result<(Message, User, Conversation, Vec<Message>, Option<Message>), SseResponse> {
|
||||
// Check authentication
|
||||
let Some(user) = current_user else {
|
||||
return Err(sse_with_keep_alive(create_error_stream(
|
||||
"You must be signed in to use this feature",
|
||||
)));
|
||||
};
|
||||
|
||||
// Retrieve message
|
||||
let message = match db.get_item::<Message>(message_id).await {
|
||||
Ok(Some(message)) => message,
|
||||
Ok(None) => {
|
||||
@@ -88,7 +84,6 @@ async fn get_message_and_user(
|
||||
}
|
||||
};
|
||||
|
||||
// Get conversation history
|
||||
let (conversation, history) =
|
||||
match Conversation::get_complete_conversation(&message.conversation_id, &user.id, db).await
|
||||
{
|
||||
@@ -209,7 +204,6 @@ pub async fn get_response_stream(
|
||||
auth: AuthSessionType,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> SseResponse {
|
||||
// 1. Authentication and initial data validation
|
||||
let (user_message, user, _conversation, history, existing_ai_response) =
|
||||
match get_message_and_user(&state.db, auth.current_user, ¶ms.message_id).await {
|
||||
Ok((user_message, user, conversation, history, existing_ai_response)) => (
|
||||
@@ -226,9 +220,123 @@ pub async fn get_response_stream(
|
||||
return create_replayed_response_stream(&state, existing_ai_message);
|
||||
}
|
||||
|
||||
// 2. Retrieve knowledge entities
|
||||
let (request, allowed_reference_ids) = match prepare_chat_request(&state, &user_message, &user, &history).await {
|
||||
Ok(result) => result,
|
||||
Err(sse) => return sse,
|
||||
};
|
||||
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
|
||||
build_chat_event_stream(state, openai_stream, &user_message, user.id.clone(), allowed_reference_ids)
|
||||
}
|
||||
|
||||
fn build_chat_event_stream(
|
||||
state: HtmlState,
|
||||
openai_stream: impl Stream<Item = Result<async_openai::types::CreateChatCompletionStreamResponse, async_openai::error::OpenAIError>> + Send + 'static,
|
||||
user_message: &Message,
|
||||
user_id: String,
|
||||
allowed_reference_ids: Vec<String>,
|
||||
) -> SseResponse {
|
||||
let (tx, rx) = channel::<String>(1000);
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
|
||||
spawn_storage_task(Arc::clone(&state.db), rx, tx_final, user_message, user_id, allowed_reference_ids);
|
||||
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {e}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
#[derive(Serialize)]
|
||||
struct LocalReferenceData {
|
||||
message: Message,
|
||||
}
|
||||
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
if message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(std::vec::Vec::is_empty)
|
||||
{
|
||||
return Ok(Event::default().event("empty"));
|
||||
}
|
||||
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(LocalReferenceData { message }),
|
||||
) {
|
||||
Ok(html) => Ok(Event::default().event("references").data(html)),
|
||||
Err(_) => Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references")),
|
||||
}
|
||||
} else {
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}))
|
||||
.boxed();
|
||||
|
||||
sse_with_keep_alive(event_stream)
|
||||
}
|
||||
|
||||
async fn prepare_chat_request(
|
||||
state: &HtmlState,
|
||||
user_message: &Message,
|
||||
user: &User,
|
||||
history: &[Message],
|
||||
) -> Result<
|
||||
(async_openai::types::CreateChatCompletionRequest, Vec<String>),
|
||||
Sse<Pin<Box<dyn Stream<Item = Result<Event, axum::Error>> + Send>>>,
|
||||
> {
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -248,59 +356,49 @@ pub async fn get_response_stream(
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve knowledge"));
|
||||
return Err(Sse::new(create_error_stream("Failed to retrieve knowledge")));
|
||||
}
|
||||
};
|
||||
|
||||
let allowed_reference_ids = collect_reference_ids_from_retrieval(&retrieval_result);
|
||||
|
||||
// 3. Create the OpenAI request with appropriate context format
|
||||
let context_json = match &retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(chunks),
|
||||
let context_json = match retrieval_result {
|
||||
retrieval_pipeline::StrategyOutput::Chunks(chunks) => chunks_to_chat_context(&chunks),
|
||||
retrieval_pipeline::StrategyOutput::Entities(entities) => {
|
||||
retrieved_entities_to_json(entities)
|
||||
}
|
||||
retrieval_pipeline::StrategyOutput::Search(search_result) => {
|
||||
// For chat, use chunks from the search result
|
||||
chunks_to_chat_context(&search_result.chunks)
|
||||
}
|
||||
};
|
||||
let formatted_user_message =
|
||||
create_user_message_with_history(&context_json, &history, &user_message.content);
|
||||
create_user_message_with_history(&context_json, history, &user_message.content);
|
||||
let Ok(settings) = SystemSettings::get_current(&state.db).await else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to retrieve system settings"));
|
||||
return Err(Sse::new(create_error_stream("Failed to retrieve system settings")));
|
||||
};
|
||||
let Ok(request) = create_chat_request(formatted_user_message, &settings) else {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create chat request"));
|
||||
return Err(Sse::new(create_error_stream("Failed to create chat request")));
|
||||
};
|
||||
|
||||
// 4. Set up the OpenAI stream
|
||||
let openai_stream = match state.openai_client.chat().create_stream(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_e) => {
|
||||
return sse_with_keep_alive(create_error_stream("Failed to create OpenAI stream"));
|
||||
}
|
||||
};
|
||||
Ok((request, allowed_reference_ids))
|
||||
}
|
||||
|
||||
// 5. Create channel for collecting complete response
|
||||
let (tx, mut rx) = channel::<String>(1000);
|
||||
let tx_clone = tx.clone();
|
||||
let (tx_final, mut rx_final) = channel::<Message>(1);
|
||||
fn spawn_storage_task(
|
||||
db_client: Arc<SurrealDbClient>,
|
||||
mut rx: tokio::sync::mpsc::Receiver<String>,
|
||||
tx_final: tokio::sync::mpsc::Sender<Message>,
|
||||
user_message: &Message,
|
||||
user_id: String,
|
||||
allowed_reference_ids: Vec<String>,
|
||||
) {
|
||||
let conversation_id = user_message.conversation_id.clone();
|
||||
|
||||
// 6. Set up the collection task for DB storage
|
||||
let db_client = Arc::clone(&state.db);
|
||||
let user_id = user.id.clone();
|
||||
let allowed_reference_ids = allowed_reference_ids.clone();
|
||||
tokio::spawn(async move {
|
||||
drop(tx); // Close sender when no longer needed
|
||||
|
||||
// Collect full response
|
||||
let mut full_json = String::new();
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
full_json.push_str(&chunk);
|
||||
}
|
||||
|
||||
// Try to extract structured data
|
||||
if let Ok(response) = from_str::<LLMResponseFormat>(&full_json) {
|
||||
let raw_references = extract_reference_strings(&response);
|
||||
let answer = response.answer;
|
||||
@@ -347,7 +445,7 @@ pub async fn get_response_stream(
|
||||
);
|
||||
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
conversation_id,
|
||||
MessageRole::AI,
|
||||
answer,
|
||||
Some(initial_validation.valid_refs),
|
||||
@@ -362,104 +460,11 @@ pub async fn get_response_stream(
|
||||
} else {
|
||||
error!("Failed to parse LLM response as structured format");
|
||||
|
||||
// Fallback - store raw response
|
||||
let ai_message = Message::new(
|
||||
user_message.conversation_id,
|
||||
MessageRole::AI,
|
||||
full_json,
|
||||
None,
|
||||
);
|
||||
let ai_message = Message::new(conversation_id, MessageRole::AI, full_json, None);
|
||||
|
||||
let _ = db_client.store_item(ai_message).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create a shared state for tracking the JSON parsing
|
||||
let json_state = Arc::new(Mutex::new(StreamParserState::new()));
|
||||
|
||||
// 7. Create the response event stream
|
||||
let event_stream = openai_stream
|
||||
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
|
||||
.map(move |result| {
|
||||
let tx_storage = tx_clone.clone();
|
||||
let json_state = Arc::clone(&json_state);
|
||||
|
||||
stream! {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let content = response
|
||||
.choices
|
||||
.first()
|
||||
.and_then(|choice| choice.delta.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !content.is_empty() {
|
||||
// Always send raw content to storage
|
||||
let _ = tx_storage.send(content.clone()).await;
|
||||
|
||||
// Process through JSON parser
|
||||
let mut state = json_state.lock().await;
|
||||
let display_content = state.process_chunk(&content);
|
||||
drop(state);
|
||||
if !display_content.is_empty() {
|
||||
yield Ok(Event::default()
|
||||
.event("chat_message")
|
||||
.data(display_content));
|
||||
}
|
||||
// If display_content is empty, don't yield anything
|
||||
}
|
||||
// If content is empty, don't yield anything
|
||||
}
|
||||
Err(e) => {
|
||||
yield Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!("Stream error: {e}")));
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.flatten()
|
||||
.chain(stream::once(async move {
|
||||
if let Some(message) = rx_final.recv().await {
|
||||
// Don't send any event if references is empty
|
||||
if message
|
||||
.references
|
||||
.as_ref()
|
||||
.is_some_and(std::vec::Vec::is_empty)
|
||||
{
|
||||
return Ok(Event::default().event("empty")); // This event won't be sent
|
||||
}
|
||||
|
||||
// Render template with references
|
||||
match state.templates.render(
|
||||
"chat/reference_list.html",
|
||||
&Value::from_serialize(ReferenceData { message }),
|
||||
) {
|
||||
Ok(html) => {
|
||||
// Return the rendered HTML
|
||||
Ok(Event::default().event("references").data(html))
|
||||
}
|
||||
Err(_) => {
|
||||
// Handle template rendering error
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to render references"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle case where no references were received
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data("Failed to retrieve references"))
|
||||
}
|
||||
}))
|
||||
.chain(once(async {
|
||||
Ok(Event::default()
|
||||
.event("close_stream")
|
||||
.data("Stream complete"))
|
||||
}));
|
||||
|
||||
sse_with_keep_alive(event_stream.boxed())
|
||||
}
|
||||
|
||||
struct StreamParserState {
|
||||
@@ -478,23 +483,18 @@ impl StreamParserState {
|
||||
}
|
||||
|
||||
fn process_chunk(&mut self, chunk: &str) -> String {
|
||||
// Feed all characters into the parser
|
||||
for c in chunk.chars() {
|
||||
let _ = self.parser.add_char(c);
|
||||
}
|
||||
|
||||
// Get the current state of the JSON
|
||||
let json = self.parser.get_result();
|
||||
|
||||
// Check if we're in the answer field
|
||||
if let Some(obj) = json.as_object() {
|
||||
if let Some(answer) = obj.get("answer") {
|
||||
self.in_answer_field = true;
|
||||
|
||||
// Get current answer content
|
||||
let current_content = answer.as_str().unwrap_or_default().to_string();
|
||||
|
||||
// Calculate difference to send only new content
|
||||
if current_content.len() > self.last_answer_content.len() {
|
||||
let new_content = current_content[self.last_answer_content.len()..].to_string();
|
||||
self.last_answer_content = current_content;
|
||||
@@ -503,7 +503,6 @@ impl StreamParserState {
|
||||
}
|
||||
}
|
||||
|
||||
// No new content to return
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,8 +10,9 @@ use axum::{
|
||||
};
|
||||
pub use chat_handlers::{
|
||||
delete_conversation, new_chat_user_message, new_user_message, patch_conversation_title,
|
||||
reload_sidebar, show_chat_base, show_conversation_editing_title, show_existing_chat,
|
||||
show_initialized_chat,
|
||||
reload_sidebar, show_conversation_editing_title,
|
||||
show_chat_base as show_base, show_existing_chat as show_existing,
|
||||
show_initialized_chat as show_initialized,
|
||||
};
|
||||
use message_response_stream::get_response_stream;
|
||||
use references::show_reference_tooltip;
|
||||
@@ -24,10 +25,10 @@ where
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new()
|
||||
.route("/chat", get(show_chat_base).post(new_chat_user_message))
|
||||
.route("/chat", get(show_base).post(new_chat_user_message))
|
||||
.route(
|
||||
"/chat/{id}",
|
||||
get(show_existing_chat)
|
||||
get(show_existing)
|
||||
.post(new_user_message)
|
||||
.delete(delete_conversation),
|
||||
)
|
||||
@@ -36,7 +37,7 @@ where
|
||||
get(show_conversation_editing_title).patch(patch_conversation_title),
|
||||
)
|
||||
.route("/chat/sidebar", get(reload_sidebar))
|
||||
.route("/initialized-chat", post(show_initialized_chat))
|
||||
.route("/initialized-chat", post(show_initialized))
|
||||
.route("/chat/response-stream", get(get_response_stream))
|
||||
.route("/chat/reference/{id}", get(show_reference_tooltip))
|
||||
}
|
||||
|
||||
@@ -102,13 +102,13 @@ pub async fn show_text_content_edit_form(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentEditModal {
|
||||
pub text_content: TextContent,
|
||||
}
|
||||
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"content/edit_text_content_modal.html",
|
||||
TextContentEditModal { text_content },
|
||||
@@ -214,13 +214,14 @@ pub async fn show_content_read_modal(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
// Get and validate the text content
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
#[derive(Serialize)]
|
||||
pub struct TextContentReadModalData {
|
||||
pub text_content: TextContent,
|
||||
}
|
||||
|
||||
// Get and validate the text content
|
||||
let text_content = User::get_and_validate_text_content(&id, &user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"content/read_content_modal.html",
|
||||
TextContentReadModalData { text_content },
|
||||
|
||||
@@ -226,7 +226,7 @@ fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
("Text".to_string(), truncate_summary(text, 80))
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
|
||||
("URL".to_string(), url.to_string())
|
||||
("URL".to_string(), url.clone())
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
|
||||
("File".to_string(), file_info.file_name.clone())
|
||||
@@ -248,18 +248,16 @@ pub async fn serve_file(
|
||||
RequireUser(user): RequireUser,
|
||||
Path(file_id): Path<String>,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let file_info = match FileInfo::get_by_id(&file_id, &state.db).await {
|
||||
Ok(info) => info,
|
||||
_ => return Ok(TemplateResponse::not_found().into_response()),
|
||||
let Ok(file_info) = FileInfo::get_by_id(&file_id, &state.db).await else {
|
||||
return Ok(TemplateResponse::not_found().into_response());
|
||||
};
|
||||
|
||||
if file_info.user_id != user.id {
|
||||
return Ok(TemplateResponse::unauthorized().into_response());
|
||||
}
|
||||
|
||||
let stream = match state.storage.get_stream(&file_info.path).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Ok(TemplateResponse::server_error().into_response()),
|
||||
let Ok(stream) = state.storage.get_stream(&file_info.path).await else {
|
||||
return Ok(TemplateResponse::server_error().into_response());
|
||||
};
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{pin::Pin, time::Duration};
|
||||
use std::{pin::Pin, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
@@ -51,13 +51,13 @@ pub async fn show_ingest_form(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ShowIngestFormData {
|
||||
user_categories: Vec<String>,
|
||||
}
|
||||
|
||||
let user_categories = User::get_user_categories(&user.id, &state.db).await?;
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"ingestion_modal.html",
|
||||
ShowIngestFormData { user_categories },
|
||||
@@ -180,7 +180,7 @@ pub async fn get_task_updates_stream(
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> TaskSse {
|
||||
let task_id = params.task_id.clone();
|
||||
let db = state.db.clone();
|
||||
let db = Arc::clone(&state.db);
|
||||
|
||||
// 1. Check for authenticated user
|
||||
let Some(current_user) = auth.current_user else {
|
||||
@@ -198,7 +198,7 @@ pub async fn get_task_updates_stream(
|
||||
}
|
||||
|
||||
let sse_stream = async_stream::stream! {
|
||||
let mut consecutive_db_errors = 0;
|
||||
let mut consecutive_db_errors: u32 = 0;
|
||||
let max_consecutive_db_errors = 3;
|
||||
|
||||
loop {
|
||||
@@ -263,7 +263,7 @@ pub async fn get_task_updates_stream(
|
||||
}
|
||||
Err(db_err) => {
|
||||
error!("Database error while fetching task '{}': {:?}", task_id, db_err);
|
||||
consecutive_db_errors += 1;
|
||||
consecutive_db_errors = consecutive_db_errors.saturating_add(1);
|
||||
yield Ok(Event::default().event("error").data(format!("Temporary error fetching task update (attempt {consecutive_db_errors}).")));
|
||||
|
||||
if consecutive_db_errors >= max_consecutive_db_errors {
|
||||
|
||||
@@ -39,7 +39,7 @@ use url::form_urlencoded;
|
||||
|
||||
const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12;
|
||||
const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"];
|
||||
const DEFAULT_RELATIONSHIP_TYPE: &str = RELATIONSHIP_TYPE_OPTIONS[0];
|
||||
const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo";
|
||||
const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10;
|
||||
const SUGGESTION_MIN_SCORE: f32 = 0.5;
|
||||
|
||||
@@ -61,15 +61,15 @@ fn canonicalize_relationship_type(value: &str) -> String {
|
||||
|
||||
let key: String = trimmed
|
||||
.chars()
|
||||
.filter(|c| c.is_ascii_alphanumeric())
|
||||
.flat_map(|c| c.to_lowercase())
|
||||
.filter(char::is_ascii_alphanumeric)
|
||||
.flat_map(char::to_lowercase)
|
||||
.collect();
|
||||
|
||||
for option in RELATIONSHIP_TYPE_OPTIONS {
|
||||
let option_key: String = option
|
||||
.chars()
|
||||
.filter(|c| c.is_ascii_alphanumeric())
|
||||
.flat_map(|c| c.to_lowercase())
|
||||
.filter(char::is_ascii_alphanumeric)
|
||||
.flat_map(char::to_lowercase)
|
||||
.collect();
|
||||
if option_key == key {
|
||||
return (*option).to_string();
|
||||
@@ -141,7 +141,7 @@ pub async fn show_new_knowledge_entity_form(
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(|&s| s.to_owned())
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
|
||||
let existing_entities = User::get_knowledge_entities(&user.id, &state.db).await?;
|
||||
@@ -278,7 +278,7 @@ pub async fn suggest_knowledge_relationships(
|
||||
if !query_parts.is_empty() {
|
||||
let query = query_parts.join(" ");
|
||||
let rerank_lease = match state.reranker_pool.as_ref() {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -406,9 +406,10 @@ fn build_relationship_table_data(
|
||||
.map(|relationship| {
|
||||
let relationship_type_label =
|
||||
canonicalize_relationship_type(&relationship.metadata.relationship_type);
|
||||
*frequency
|
||||
let count = frequency
|
||||
.entry(relationship_type_label.clone())
|
||||
.or_insert(0) += 1;
|
||||
.or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
RelationshipTableRow {
|
||||
relationship,
|
||||
relationship_type_label,
|
||||
@@ -417,9 +418,7 @@ fn build_relationship_table_data(
|
||||
.collect();
|
||||
let default_relationship_type = frequency
|
||||
.into_iter()
|
||||
.max_by_key(|(_, count)| *count)
|
||||
.map(|(label, _)| label)
|
||||
.unwrap_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string());
|
||||
.max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
|
||||
|
||||
RelationshipTableData {
|
||||
entities,
|
||||
@@ -800,8 +799,10 @@ pub async fn get_knowledge_graph_json(
|
||||
for rel in &relationships {
|
||||
if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) {
|
||||
// undirected counting for degree
|
||||
*degree_count.entry(rel.in_.clone()).or_insert(0) += 1;
|
||||
*degree_count.entry(rel.out.clone()).or_insert(0) += 1;
|
||||
let count = degree_count.entry(rel.in_.clone()).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
let count = degree_count.entry(rel.out.clone()).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
links.push(GraphLink {
|
||||
source: rel.out.clone(),
|
||||
target: rel.in_.clone(),
|
||||
@@ -836,11 +837,11 @@ fn normalize_filter(input: Option<String>) -> Option<String> {
|
||||
|
||||
fn trim_matching_quotes(value: &str) -> &str {
|
||||
let bytes = value.as_bytes();
|
||||
if bytes.len() >= 2 {
|
||||
let first = bytes[0];
|
||||
let last = bytes[bytes.len() - 1];
|
||||
if (first == b'"' && last == b'"') || (first == b'\'' && last == b'\'') {
|
||||
return &value[1..value.len() - 1];
|
||||
if let (Some(&first), Some(&last)) = (bytes.first(), bytes.last()) {
|
||||
if bytes.len() >= 2
|
||||
&& ((first == b'"' && last == b'"') || (first == b'\'' && last == b'\''))
|
||||
{
|
||||
return &value[1..value.len().saturating_sub(1)];
|
||||
}
|
||||
}
|
||||
value
|
||||
@@ -860,7 +861,7 @@ pub async fn show_edit_knowledge_entity_form(
|
||||
// Get entity types
|
||||
let entity_types: Vec<String> = KnowledgeEntityType::variants()
|
||||
.iter()
|
||||
.map(|&s| s.to_owned())
|
||||
.map(ToString::to_string)
|
||||
.collect();
|
||||
|
||||
// Get the entity and validate ownership
|
||||
|
||||
@@ -11,6 +11,7 @@ use axum::{
|
||||
use common::storage::types::{
|
||||
serde_helpers::deserialize_flexible_id,
|
||||
text_content::TextContent,
|
||||
user::User,
|
||||
StoredObject,
|
||||
};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
@@ -46,13 +47,11 @@ fn source_id_suffix(source_id: &str) -> String {
|
||||
|
||||
fn truncate_label(value: &str, max_chars: usize) -> String {
|
||||
let mut end = None;
|
||||
let mut count = 0;
|
||||
for (idx, _) in value.char_indices() {
|
||||
for (count, (idx, _)) in value.char_indices().enumerate() {
|
||||
if count == max_chars {
|
||||
end = Some(idx);
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
match end {
|
||||
@@ -174,165 +173,31 @@ struct KnowledgeEntityForTemplate {
|
||||
score: f32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SearchResultForTemplate {
|
||||
result_type: String,
|
||||
score: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
text_chunk: Option<TextChunkForTemplate>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
knowledge_entity: Option<KnowledgeEntityForTemplate>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AnswerData {
|
||||
search_result: Vec<SearchResultForTemplate>,
|
||||
query_param: String,
|
||||
}
|
||||
|
||||
pub async fn search_result_handler(
|
||||
State(state): State<HtmlState>,
|
||||
Query(params): Query<SearchParams>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
#[derive(Serialize)]
|
||||
struct SearchResultForTemplate {
|
||||
result_type: String,
|
||||
score: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
text_chunk: Option<TextChunkForTemplate>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
knowledge_entity: Option<KnowledgeEntityForTemplate>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct AnswerData {
|
||||
search_result: Vec<SearchResultForTemplate>,
|
||||
query_param: String,
|
||||
}
|
||||
|
||||
let (search_results_for_template, final_query_param_for_template) = if let Some(actual_query) =
|
||||
params.query
|
||||
{
|
||||
let trimmed_query = actual_query.trim();
|
||||
if trimmed_query.is_empty() {
|
||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||
} else {
|
||||
// Use retrieval pipeline Search strategy
|
||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
||||
|
||||
// Checkout a reranker lease if pool is available
|
||||
let reranker_lease = match &state.reranker_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let result = retrieval_pipeline::pipeline::run_pipeline(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&state.embedding_provider),
|
||||
trimmed_query,
|
||||
&user.id,
|
||||
config,
|
||||
reranker_lease,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let search_result = match result {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
};
|
||||
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
}
|
||||
for entity_result in &search_result.entities {
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
let source_label_map = if source_ids.is_empty() {
|
||||
HashMap::new()
|
||||
} else {
|
||||
let record_ids: Vec<RecordId> = source_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
if id.contains(':') {
|
||||
RecordId::from_str(id).ok()
|
||||
} else {
|
||||
Some(RecordId::from_table_key(TextContent::table_name(), id))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut response = state
|
||||
.db
|
||||
.client
|
||||
.query(
|
||||
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("user_id", user.id.clone()))
|
||||
.bind(("record_ids", record_ids))
|
||||
.await?;
|
||||
let contents: Vec<SourceLabelRow> = response.take(0)?;
|
||||
|
||||
tracing::debug!(
|
||||
source_id_count = source_ids.len(),
|
||||
label_row_count = contents.len(),
|
||||
"Resolved search source labels"
|
||||
);
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
for content in contents {
|
||||
let label = build_source_label(&content);
|
||||
labels.insert(content.id.clone(), label.clone());
|
||||
labels.insert(
|
||||
format!("{}:{}", TextContent::table_name(), content.id),
|
||||
label,
|
||||
);
|
||||
}
|
||||
|
||||
labels
|
||||
};
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
|
||||
|
||||
// Add chunk results
|
||||
for chunk_result in search_result.chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
text_chunk: Some(TextChunkForTemplate {
|
||||
id: chunk_result.chunk.id,
|
||||
source_id: chunk_result.chunk.source_id,
|
||||
source_label,
|
||||
chunk: chunk_result.chunk.chunk,
|
||||
score: chunk_result.score,
|
||||
}),
|
||||
knowledge_entity: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Add entity results
|
||||
for entity_result in search_result.entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
text_chunk: None,
|
||||
knowledge_entity: Some(KnowledgeEntityForTemplate {
|
||||
id: entity_result.entity.id,
|
||||
name: entity_result.entity.name,
|
||||
description: entity_result.entity.description,
|
||||
entity_type: format!("{:?}", entity_result.entity.entity_type),
|
||||
source_id: entity_result.entity.source_id,
|
||||
source_label,
|
||||
score: entity_result.score,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||
|
||||
// Limit results
|
||||
const TOTAL_LIMIT: usize = 10;
|
||||
combined_results.truncate(TOTAL_LIMIT);
|
||||
|
||||
(combined_results, trimmed_query.to_string())
|
||||
}
|
||||
perform_search(&state, &user, actual_query).await?
|
||||
} else {
|
||||
(Vec::<SearchResultForTemplate>::new(), String::new())
|
||||
};
|
||||
@@ -345,3 +210,147 @@ pub async fn search_result_handler(
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
async fn perform_search(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
query: String,
|
||||
) -> Result<(Vec<SearchResultForTemplate>, String), HtmlError> {
|
||||
const TOTAL_LIMIT: usize = 10;
|
||||
|
||||
let trimmed_query = query.trim();
|
||||
if trimmed_query.is_empty() {
|
||||
return Ok((Vec::new(), String::new()));
|
||||
}
|
||||
|
||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
||||
|
||||
let reranker_lease = match &state.reranker_pool {
|
||||
Some(pool) => pool.checkout().await,
|
||||
None => None,
|
||||
};
|
||||
|
||||
let params = retrieval_pipeline::pipeline::StrategyParams {
|
||||
db_client: &state.db,
|
||||
openai_client: &state.openai_client,
|
||||
embedding_provider: Some(&state.embedding_provider),
|
||||
input_text: trimmed_query,
|
||||
user_id: &user.id,
|
||||
config,
|
||||
reranker: reranker_lease,
|
||||
};
|
||||
let result = retrieval_pipeline::pipeline::execute(params).await?;
|
||||
|
||||
let search_result = match result {
|
||||
StrategyOutput::Search(sr) => sr,
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
};
|
||||
|
||||
let source_label_map = resolve_source_labels(state, user, &search_result).await?;
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len().saturating_add(search_result.entities.len()));
|
||||
|
||||
for chunk_result in search_result.chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
text_chunk: Some(TextChunkForTemplate {
|
||||
id: chunk_result.chunk.id,
|
||||
source_id: chunk_result.chunk.source_id,
|
||||
source_label,
|
||||
chunk: chunk_result.chunk.chunk,
|
||||
score: chunk_result.score,
|
||||
}),
|
||||
knowledge_entity: None,
|
||||
});
|
||||
}
|
||||
|
||||
for entity_result in search_result.entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
text_chunk: None,
|
||||
knowledge_entity: Some(KnowledgeEntityForTemplate {
|
||||
id: entity_result.entity.id,
|
||||
name: entity_result.entity.name,
|
||||
description: entity_result.entity.description,
|
||||
entity_type: format!("{:?}", entity_result.entity.entity_type),
|
||||
source_id: entity_result.entity.source_id,
|
||||
source_label,
|
||||
score: entity_result.score,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
|
||||
combined_results.truncate(TOTAL_LIMIT);
|
||||
|
||||
Ok((combined_results, trimmed_query.to_string()))
|
||||
}
|
||||
|
||||
async fn resolve_source_labels(
|
||||
state: &HtmlState,
|
||||
user: &User,
|
||||
search_result: &SearchResult,
|
||||
) -> Result<HashMap<String, String>, HtmlError> {
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
}
|
||||
for entity_result in &search_result.entities {
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let record_ids: Vec<RecordId> = source_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
if id.contains(':') {
|
||||
RecordId::from_str(id).ok()
|
||||
} else {
|
||||
Some(RecordId::from_table_key(TextContent::table_name(), id))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut response = state
|
||||
.db
|
||||
.client
|
||||
.query(
|
||||
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("user_id", user.id.clone()))
|
||||
.bind(("record_ids", record_ids))
|
||||
.await?;
|
||||
let contents: Vec<SourceLabelRow> = response.take(0)?;
|
||||
|
||||
tracing::debug!(
|
||||
source_id_count = source_ids.len(),
|
||||
label_row_count = contents.len(),
|
||||
"Resolved search source labels"
|
||||
);
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
for content in contents {
|
||||
let label = build_source_label(&content);
|
||||
labels.insert(content.id.clone(), label.clone());
|
||||
labels.insert(
|
||||
format!("{}:{}", TextContent::table_name(), content.id),
|
||||
label,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(labels)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
mod handlers;
|
||||
|
||||
use axum::{extract::FromRef, routing::get, Router};
|
||||
pub use handlers::{search_result_handler, SearchParams};
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub use handlers::{
|
||||
search_result_handler as result_handler, SearchParams as SearchQueryParams,
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
@@ -10,5 +13,5 @@ where
|
||||
S: Clone + Send + Sync + 'static,
|
||||
HtmlState: FromRef<S>,
|
||||
{
|
||||
Router::new().route("/search", get(search_result_handler))
|
||||
Router::new().route("/search", get(result_handler))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user