mirror of
https://github.com/perstarkse/minne.git
synced 2026-05-02 05:34:16 +02:00
tidying stuff up, dto for search
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use common::storage::{db::SurrealDbClient, store::StorageManager};
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use common::utils::template_engine::{ProvidesTemplateEngine, TemplateEngine};
|
||||
use common::{create_template_engine, storage::db::ProvidesDb, utils::config::AppConfig};
|
||||
use retrieval_pipeline::{reranking::RerankerPool, RetrievalStrategy};
|
||||
@@ -16,6 +17,7 @@ pub struct HtmlState {
|
||||
pub config: AppConfig,
|
||||
pub storage: StorageManager,
|
||||
pub reranker_pool: Option<Arc<RerankerPool>>,
|
||||
pub embedding_provider: Arc<EmbeddingProvider>,
|
||||
}
|
||||
|
||||
impl HtmlState {
|
||||
@@ -26,6 +28,7 @@ impl HtmlState {
|
||||
storage: StorageManager,
|
||||
config: AppConfig,
|
||||
reranker_pool: Option<Arc<RerankerPool>>,
|
||||
embedding_provider: Arc<EmbeddingProvider>,
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let template_engine = create_template_engine!("templates");
|
||||
debug!("Template engine created for html_router.");
|
||||
@@ -38,6 +41,7 @@ impl HtmlState {
|
||||
config,
|
||||
storage,
|
||||
reranker_pool,
|
||||
embedding_provider,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -151,21 +151,46 @@ pub async fn update_model_settings(
|
||||
|
||||
let current_settings = SystemSettings::get_current(&state.db).await?;
|
||||
|
||||
// Determine if re-embedding is required
|
||||
let reembedding_needed = input
|
||||
.embedding_dimensions
|
||||
.is_some_and(|new_dims| new_dims != current_settings.embedding_dimensions);
|
||||
// Check if using FastEmbed - if so, embedding model/dimensions cannot be changed via UI
|
||||
let uses_local_embeddings = current_settings
|
||||
.embedding_backend
|
||||
.as_deref()
|
||||
.is_some_and(|b| b == "fastembed" || b == "hashed");
|
||||
|
||||
// For local embeddings, ignore any embedding model/dimension changes from the form
|
||||
let (final_embedding_model, final_embedding_dimensions, reembedding_needed) =
|
||||
if uses_local_embeddings {
|
||||
// Keep current values - they're controlled by config, not the admin UI
|
||||
info!(
|
||||
backend = ?current_settings.embedding_backend,
|
||||
"Embedding model/dimensions controlled by config, ignoring form input"
|
||||
);
|
||||
(
|
||||
current_settings.embedding_model.clone(),
|
||||
current_settings.embedding_dimensions,
|
||||
false,
|
||||
)
|
||||
} else {
|
||||
// OpenAI backend - allow changes from form
|
||||
let reembedding_needed = input
|
||||
.embedding_dimensions
|
||||
.is_some_and(|new_dims| new_dims != current_settings.embedding_dimensions);
|
||||
(
|
||||
input.embedding_model,
|
||||
input
|
||||
.embedding_dimensions
|
||||
.unwrap_or(current_settings.embedding_dimensions),
|
||||
reembedding_needed,
|
||||
)
|
||||
};
|
||||
|
||||
let new_settings = SystemSettings {
|
||||
query_model: input.query_model,
|
||||
processing_model: input.processing_model,
|
||||
image_processing_model: input.image_processing_model,
|
||||
voice_processing_model: input.voice_processing_model,
|
||||
embedding_model: input.embedding_model,
|
||||
// Use new dimensions if provided, otherwise retain the current ones.
|
||||
embedding_dimensions: input
|
||||
.embedding_dimensions
|
||||
.unwrap_or(current_settings.embedding_dimensions),
|
||||
embedding_model: final_embedding_model,
|
||||
embedding_dimensions: final_embedding_dimensions,
|
||||
..current_settings.clone()
|
||||
};
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ pub async fn get_response_stream(
|
||||
let retrieval_result = match retrieval_pipeline::retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&*state.embedding_provider),
|
||||
&user_message.content,
|
||||
&user.id,
|
||||
config,
|
||||
|
||||
@@ -288,6 +288,7 @@ pub async fn suggest_knowledge_relationships(
|
||||
retrieval_pipeline::retrieve_entities(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
Some(&*state.embedding_provider),
|
||||
&query,
|
||||
&user.id,
|
||||
config,
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
use std::{fmt, str::FromStr, time::Duration};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fmt, str::FromStr,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use common::storage::types::{conversation::Conversation, user::User};
|
||||
use common::storage::types::{
|
||||
conversation::Conversation,
|
||||
text_content::{deserialize_flexible_id, TextContent},
|
||||
user::User,
|
||||
StoredObject,
|
||||
};
|
||||
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
|
||||
use serde::{de, Deserialize, Deserializer, Serialize};
|
||||
use tokio::time::error::Elapsed;
|
||||
use surrealdb::RecordId;
|
||||
|
||||
use crate::{
|
||||
html_state::HtmlState,
|
||||
@@ -31,6 +39,113 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn source_id_suffix(source_id: &str) -> String {
|
||||
let start = source_id.len().saturating_sub(8);
|
||||
source_id[start..].to_string()
|
||||
}
|
||||
|
||||
fn truncate_label(value: &str, max_chars: usize) -> String {
|
||||
let mut end = None;
|
||||
let mut count = 0;
|
||||
for (idx, _) in value.char_indices() {
|
||||
if count == max_chars {
|
||||
end = Some(idx);
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
match end {
|
||||
Some(idx) => format!("{}...", &value[..idx]),
|
||||
None => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn first_non_empty_line(text: &str, max_chars: usize) -> Option<String> {
|
||||
for line in text.lines() {
|
||||
let trimmed = line.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(truncate_label(trimmed, max_chars));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UrlInfoLabel {
|
||||
#[serde(default)]
|
||||
title: String,
|
||||
#[serde(default)]
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct FileInfoLabel {
|
||||
#[serde(default)]
|
||||
file_name: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SourceLabelRow {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
#[serde(default)]
|
||||
url_info: Option<UrlInfoLabel>,
|
||||
#[serde(default)]
|
||||
file_info: Option<FileInfoLabel>,
|
||||
#[serde(default)]
|
||||
context: Option<String>,
|
||||
#[serde(default)]
|
||||
category: String,
|
||||
#[serde(default)]
|
||||
text: String,
|
||||
}
|
||||
|
||||
fn build_source_label(row: &SourceLabelRow) -> String {
|
||||
const MAX_LABEL_CHARS: usize = 80;
|
||||
|
||||
if let Some(url_info) = row.url_info.as_ref() {
|
||||
let title = url_info.title.trim();
|
||||
if !title.is_empty() {
|
||||
return title.to_string();
|
||||
}
|
||||
|
||||
let url = url_info.url.trim();
|
||||
if !url.is_empty() {
|
||||
return url.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(file_info) = row.file_info.as_ref() {
|
||||
let name = file_info.file_name.trim();
|
||||
if !name.is_empty() {
|
||||
return name.to_string();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(context) = row.context.as_ref() {
|
||||
let trimmed = context.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return truncate_label(trimmed, MAX_LABEL_CHARS);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(text_label) = first_non_empty_line(&row.text, MAX_LABEL_CHARS) {
|
||||
return text_label;
|
||||
}
|
||||
|
||||
let category = row.category.trim();
|
||||
if !category.is_empty() {
|
||||
return truncate_label(category, MAX_LABEL_CHARS);
|
||||
}
|
||||
|
||||
format!("Text snippet: {}", source_id_suffix(&row.id))
|
||||
}
|
||||
|
||||
fn fallback_source_label(source_id: &str) -> String {
|
||||
format!("Text snippet: {}", source_id_suffix(source_id))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SearchParams {
|
||||
#[serde(default, deserialize_with = "empty_string_as_none")]
|
||||
@@ -42,6 +157,7 @@ pub struct SearchParams {
|
||||
struct TextChunkForTemplate {
|
||||
id: String,
|
||||
source_id: String,
|
||||
source_label: String,
|
||||
chunk: String,
|
||||
score: f32,
|
||||
}
|
||||
@@ -54,6 +170,7 @@ struct KnowledgeEntityForTemplate {
|
||||
description: String,
|
||||
entity_type: String,
|
||||
source_id: String,
|
||||
source_label: String,
|
||||
score: f32,
|
||||
}
|
||||
|
||||
@@ -89,14 +206,21 @@ pub async fn search_result_handler(
|
||||
} else {
|
||||
// Use retrieval pipeline Search strategy
|
||||
let config = RetrievalConfig::for_search(SearchTarget::Both);
|
||||
|
||||
// Checkout a reranker lease if pool is available
|
||||
let reranker_lease = match &state.reranker_pool {
|
||||
Some(pool) => Some(pool.checkout().await),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let result = retrieval_pipeline::pipeline::run_pipeline(
|
||||
&state.db,
|
||||
&state.openai_client,
|
||||
None, // No embedding provider in HtmlState
|
||||
Some(&state.embedding_provider),
|
||||
trimmed_query,
|
||||
&user.id,
|
||||
config,
|
||||
None, // No reranker for now
|
||||
reranker_lease,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -105,17 +229,74 @@ pub async fn search_result_handler(
|
||||
_ => SearchResult::new(vec![], vec![]),
|
||||
};
|
||||
|
||||
let mut source_ids = HashSet::new();
|
||||
for chunk_result in &search_result.chunks {
|
||||
source_ids.insert(chunk_result.chunk.source_id.clone());
|
||||
}
|
||||
for entity_result in &search_result.entities {
|
||||
source_ids.insert(entity_result.entity.source_id.clone());
|
||||
}
|
||||
|
||||
let source_label_map = if source_ids.is_empty() {
|
||||
HashMap::new()
|
||||
} else {
|
||||
let record_ids: Vec<RecordId> = source_ids
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
if id.contains(':') {
|
||||
RecordId::from_str(id).ok()
|
||||
} else {
|
||||
Some(RecordId::from_table_key(TextContent::table_name(), id))
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let mut response = state
|
||||
.db
|
||||
.client
|
||||
.query(
|
||||
"SELECT id, url_info, file_info, context, category, text FROM type::table($table_name) WHERE user_id = $user_id AND id INSIDE $record_ids",
|
||||
)
|
||||
.bind(("table_name", TextContent::table_name()))
|
||||
.bind(("user_id", user.id.clone()))
|
||||
.bind(("record_ids", record_ids))
|
||||
.await?;
|
||||
let contents: Vec<SourceLabelRow> = response.take(0)?;
|
||||
|
||||
tracing::debug!(
|
||||
source_id_count = source_ids.len(),
|
||||
label_row_count = contents.len(),
|
||||
"Resolved search source labels"
|
||||
);
|
||||
|
||||
let mut labels = HashMap::new();
|
||||
for content in contents {
|
||||
let label = build_source_label(&content);
|
||||
labels.insert(content.id.clone(), label.clone());
|
||||
labels.insert(
|
||||
format!("{}:{}", TextContent::table_name(), content.id),
|
||||
label,
|
||||
);
|
||||
}
|
||||
|
||||
labels
|
||||
};
|
||||
|
||||
let mut combined_results: Vec<SearchResultForTemplate> =
|
||||
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
|
||||
|
||||
// Add chunk results
|
||||
for chunk_result in search_result.chunks {
|
||||
let source_label = source_label_map
|
||||
.get(&chunk_result.chunk.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&chunk_result.chunk.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "text_chunk".to_string(),
|
||||
score: chunk_result.score,
|
||||
text_chunk: Some(TextChunkForTemplate {
|
||||
id: chunk_result.chunk.id,
|
||||
source_id: chunk_result.chunk.source_id,
|
||||
source_label,
|
||||
chunk: chunk_result.chunk.chunk,
|
||||
score: chunk_result.score,
|
||||
}),
|
||||
@@ -125,6 +306,10 @@ pub async fn search_result_handler(
|
||||
|
||||
// Add entity results
|
||||
for entity_result in search_result.entities {
|
||||
let source_label = source_label_map
|
||||
.get(&entity_result.entity.source_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| fallback_source_label(&entity_result.entity.source_id));
|
||||
combined_results.push(SearchResultForTemplate {
|
||||
result_type: "knowledge_entity".to_string(),
|
||||
score: entity_result.score,
|
||||
@@ -135,6 +320,7 @@ pub async fn search_result_handler(
|
||||
description: entity_result.entity.description,
|
||||
entity_type: format!("{:?}", entity_result.entity.entity_type),
|
||||
source_id: entity_result.entity.source_id,
|
||||
source_label,
|
||||
score: entity_result.score,
|
||||
}),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user