passed wide smoke check

This commit is contained in:
Per Stark
2025-12-10 13:54:08 +01:00
parent 2e2ea0c4ff
commit a5bc72aedf
12 changed files with 403 additions and 235 deletions

4
.gitignore vendored
View File

@@ -10,8 +10,8 @@ result
data data
database database
eval/cache/ evaluations/cache/
eval/reports/ evaluations/reports/
# Devenv # Devenv
.devenv* .devenv*

View File

@@ -97,26 +97,6 @@ impl SurrealDbClient {
Ok(()) Ok(())
} }
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), AppError> {
debug!("Rebuilding indexes");
let rebuild_sql = r#"
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding;
REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding;
"#;
self.client
.query(rebuild_sql)
.await
.map_err(|e| AppError::InternalError(e.to_string()))?;
Ok(())
}
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject /// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
/// ///
/// # Arguments /// # Arguments

View File

@@ -1,19 +1,9 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss,
clippy::redundant_closure_for_method_calls,
clippy::single_match_else,
clippy::uninlined_format_args
)]
use std::time::Duration; use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use futures::future::try_join_all; use futures::future::try_join_all;
use serde::Deserialize; use serde::Deserialize;
use serde_json::Value; use serde_json::{Map, Value};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::{error::AppError, storage::db::SurrealDbClient}; use crate::{error::AppError, storage::db::SurrealDbClient};
@@ -28,6 +18,82 @@ struct HnswIndexSpec {
options: &'static str, options: &'static str,
} }
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}
impl HnswIndexSpec { impl HnswIndexSpec {
fn definition_if_not_exists(&self, dimension: usize) -> String { fn definition_if_not_exists(&self, dimension: usize) -> String {
format!( format!(
@@ -75,6 +141,20 @@ impl FtsIndexSpec {
field = self.field, field = self.field,
) )
} }
fn overwrite_definition(&self) -> String {
let analyzer_clause = self
.analyzer
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
.unwrap_or_default();
format!(
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
index = self.index_name,
table = self.table,
field = self.field,
)
}
} }
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling. /// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
@@ -88,6 +168,13 @@ pub async fn ensure_runtime_indexes(
.map_err(|err| AppError::InternalError(err.to_string())) .map_err(|err| AppError::InternalError(err.to_string()))
} }
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
async fn ensure_runtime_indexes_inner( async fn ensure_runtime_indexes_inner(
db: &SurrealDbClient, db: &SurrealDbClient,
embedding_dimension: usize, embedding_dimension: usize,
@@ -147,32 +234,68 @@ async fn ensure_runtime_indexes_inner(
Ok(()) Ok(())
} }
async fn hnsw_index_state( async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?;
for spec in fts_index_specs() {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping FTS rebuild because index is missing"
);
continue;
}
create_index_with_polling(
db,
spec.overwrite_definition(),
spec.index_name,
spec.table,
Some(spec.table),
)
.await?;
}
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping HNSW rebuild because index is missing"
);
return Ok(());
}
let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
warn!(
index = spec.index_name,
table = spec.table,
"HNSW index missing dimension; skipping rebuild"
);
return Ok(());
};
create_index_with_polling(
db,
spec.definition_overwrite(dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
});
try_join_all(hnsw_tasks).await.map(|_| ())
}
async fn existing_hnsw_dimension(
db: &SurrealDbClient, db: &SurrealDbClient,
spec: &HnswIndexSpec, spec: &HnswIndexSpec,
expected_dimension: usize, ) -> Result<Option<usize>> {
) -> Result<HnswIndexState> { let Some(indexes) = table_index_definitions(db, spec.table).await? else {
let info_query = format!("INFO FOR TABLE {table};", table = spec.table); return Ok(None);
let mut response = db
.client
.query(info_query)
.await
.with_context(|| format!("fetching table info for {}", spec.table))?;
let info: surrealdb::Value = response
.take(0)
.context("failed to take table info response")?;
let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
let Some(indexes) = info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
else {
return Ok(HnswIndexState::Missing);
}; };
let Some(definition) = indexes let Some(definition) = indexes
@@ -180,17 +303,23 @@ async fn hnsw_index_state(
.and_then(|details| details.get("Strand")) .and_then(|details| details.get("Strand"))
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
else { else {
return Ok(HnswIndexState::Missing); return Ok(None);
}; };
let Some(current_dimension) = extract_dimension(definition) else { Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
return Ok(HnswIndexState::Missing); }
};
if current_dimension == expected_dimension as u64 { async fn hnsw_index_state(
Ok(HnswIndexState::Matches) db: &SurrealDbClient,
} else { spec: &HnswIndexSpec,
Ok(HnswIndexState::Different(current_dimension)) expected_dimension: usize,
) -> Result<HnswIndexState> {
match existing_hnsw_dimension(db, spec).await? {
None => Ok(HnswIndexState::Missing),
Some(current_dimension) if current_dimension == expected_dimension => {
Ok(HnswIndexState::Matches)
}
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
} }
} }
@@ -492,7 +621,10 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
Ok(rows.first().map_or(0, |r| r.count)) Ok(rows.first().map_or(0, |r| r.count))
} }
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> { async fn table_index_definitions(
db: &SurrealDbClient,
table: &str,
) -> Result<Option<Map<String, Value>>> {
let info_query = format!("INFO FOR TABLE {table};"); let info_query = format!("INFO FOR TABLE {table};");
let mut response = db let mut response = db
.client .client
@@ -507,94 +639,22 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
let info_json: Value = let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?; serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
let Some(indexes) = info_json Ok(info_json
.get("Object") .get("Object")
.and_then(|o| o.get("indexes")) .and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object")) .and_then(|i| i.get("Object"))
.and_then(|i| i.as_object()) .and_then(|i| i.as_object())
else { .cloned())
}
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
let Some(indexes) = table_index_definitions(db, table).await? else {
return Ok(false); return Ok(false);
}; };
Ok(indexes.contains_key(index_name)) Ok(indexes.contains_key(index_name))
} }
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@@ -17,8 +17,11 @@ pub type DynStore = Arc<dyn ObjectStore>;
/// Storage manager with persistent state and proper lifecycle management. /// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)] #[derive(Clone)]
pub struct StorageManager { pub struct StorageManager {
// Store from objectstore wrapped as dyn
store: DynStore, store: DynStore,
// Simple enum to track which kind
backend_kind: StorageKind, backend_kind: StorageKind,
// Where on disk
local_base: Option<PathBuf>, local_base: Option<PathBuf>,
} }

View File

@@ -61,36 +61,34 @@ impl TextChunk {
embedding: Vec<f32>, embedding: Vec<f32>,
db: &SurrealDbClient, db: &SurrealDbClient,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
let emb = TextChunkEmbedding::new( let chunk_id = chunk.id.clone();
&chunk.id, let source_id = chunk.source_id.clone();
chunk.source_id.clone(), let user_id = chunk.user_id.clone();
embedding,
chunk.user_id.clone(),
);
// Create both records in a single query let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone());
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
chunk_table = Self::table_name(),
emb_table = TextChunkEmbedding::table_name(),
);
db.client // Create both records in a single transaction so we don't orphan embeddings or chunks
.query(query) let response = db
.bind(("chunk_id", chunk.id.clone())) .client
.query("BEGIN TRANSACTION;")
.query(format!(
"CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;",
chunk_table = Self::table_name(),
))
.query(format!(
"CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;",
emb_table = TextChunkEmbedding::table_name(),
))
.query("COMMIT TRANSACTION;")
.bind(("chunk_id", chunk_id.clone()))
.bind(("chunk", chunk)) .bind(("chunk", chunk))
.bind(("emb_id", emb.id.clone())) .bind(("emb_id", emb.id.clone()))
.bind(("emb", emb)) .bind(("emb", emb))
.await .await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?; .map_err(AppError::Database)?;
response.check().map_err(AppError::Database)?;
Ok(()) Ok(())
} }
@@ -330,6 +328,7 @@ impl TextChunk {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding; use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId; use surrealdb::RecordId;
use uuid::Uuid; use uuid::Uuid;
@@ -524,6 +523,46 @@ mod tests {
assert_eq!(embedding.source_id, source_id); assert_eq!(embedding.source_id, source_id);
} }
#[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() {
let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension)
.await
.expect("ensure runtime indexes");
let chunk = TextChunk::new(
"runtime_src".to_string(),
"runtime chunk body".to_string(),
"runtime_user".to_string(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some(), "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some(), "embedding should exist");
assert_eq!(
embedding.unwrap().embedding.len(),
embedding_dimension,
"embedding dimension should match runtime index"
);
}
#[tokio::test] #[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() { async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns"; let namespace = "test_ns";
@@ -625,7 +664,7 @@ mod tests {
.expect("Failed to start in-memory surrealdb"); .expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations"); db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await; ensure_chunk_fts_index(&db).await;
db.rebuild_indexes().await.expect("rebuild indexes"); rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(5, "hello", &db, "user") let results = TextChunk::fts_search(5, "hello", &db, "user")
.await .await
@@ -651,7 +690,7 @@ mod tests {
user_id.to_string(), user_id.to_string(),
); );
db.store_item(chunk.clone()).await.expect("store chunk"); db.store_item(chunk.clone()).await.expect("store chunk");
db.rebuild_indexes().await.expect("rebuild indexes"); rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "rust", &db, user_id) let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await .await
@@ -698,7 +737,7 @@ mod tests {
db.store_item(other_user_chunk) db.store_item(other_user_chunk)
.await .await
.expect("store other user chunk"); .expect("store other user chunk");
db.rebuild_indexes().await.expect("rebuild indexes"); rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "apple", &db, user_id) let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await .await

View File

@@ -20,8 +20,8 @@ use crate::{
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend { pub enum EmbeddingBackend {
OpenAI,
#[default] #[default]
OpenAI,
FastEmbed, FastEmbed,
Hashed, Hashed,
} }
@@ -276,9 +276,7 @@ fn bucket(token: &str, dimension: usize) -> usize {
let safe_dimension = dimension.max(1); let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
token.hash(&mut hasher); token.hash(&mut hasher);
usize::try_from(hasher.finish()) usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
.unwrap_or_default()
% safe_dimension
} }
// Backward compatibility function // Backward compatibility function

View File

@@ -215,16 +215,30 @@ impl EvaluationCandidate {
} }
} }
fn candidates_from_entities(entities: Vec<RetrievedEntity>) -> Vec<EvaluationCandidate> {
entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect()
}
fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidate> {
chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect()
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> { pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
match output { match output {
StrategyOutput::Entities(entities) => entities StrategyOutput::Entities(entities) => candidates_from_entities(entities),
.into_iter() StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
.map(EvaluationCandidate::from_entity) StrategyOutput::Search(search_result) => {
.collect(), let mut candidates = candidates_from_entities(search_result.entities);
StrategyOutput::Chunks(chunks) => chunks candidates.extend(candidates_from_chunks(search_result.chunks));
.into_iter() candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
.map(EvaluationCandidate::from_chunk) candidates
.collect(), }
} }
} }

View File

@@ -151,6 +151,10 @@ pub async fn get_response_stream(
retrieval_pipeline::StrategyOutput::Entities(entities) => { retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities) retrieved_entities_to_json(&entities)
} }
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks)
}
}; };
let formatted_user_message = let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content); create_user_message_with_history(&context_json, &history, &user_message.content);

View File

@@ -1,17 +1,13 @@
use std::{fmt, str::FromStr}; use std::{fmt, str::FromStr, time::Duration};
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::IntoResponse, response::IntoResponse,
}; };
use common::storage::types::{ use common::storage::types::{conversation::Conversation, user::User};
conversation::Conversation, use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
knowledge_entity::{KnowledgeEntity, KnowledgeEntitySearchResult},
text_content::{TextContent, TextContentSearchResult},
user::User,
};
use futures::future::try_join;
use serde::{de, Deserialize, Deserializer, Serialize}; use serde::{de, Deserialize, Deserializer, Serialize};
use tokio::time::error::Elapsed;
use crate::{ use crate::{
html_state::HtmlState, html_state::HtmlState,
@@ -20,6 +16,7 @@ use crate::{
response_middleware::{HtmlError, TemplateResponse}, response_middleware::{HtmlError, TemplateResponse},
}, },
}; };
/// Serde deserialization decorator to map empty Strings to None, /// Serde deserialization decorator to map empty Strings to None,
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error> fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where where
@@ -40,6 +37,26 @@ pub struct SearchParams {
query: Option<String>, query: Option<String>,
} }
/// Chunk result for template rendering
#[derive(Serialize)]
struct TextChunkForTemplate {
id: String,
source_id: String,
chunk: String,
score: f32,
}
/// Entity result for template rendering (from pipeline)
#[derive(Serialize)]
struct KnowledgeEntityForTemplate {
id: String,
name: String,
description: String,
entity_type: String,
source_id: String,
score: f32,
}
pub async fn search_result_handler( pub async fn search_result_handler(
State(state): State<HtmlState>, State(state): State<HtmlState>,
Query(params): Query<SearchParams>, Query(params): Query<SearchParams>,
@@ -50,9 +67,9 @@ pub async fn search_result_handler(
result_type: String, result_type: String,
score: f32, score: f32,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
text_content: Option<TextContentSearchResult>, text_chunk: Option<TextChunkForTemplate>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
knowledge_entity: Option<KnowledgeEntitySearchResult>, knowledge_entity: Option<KnowledgeEntityForTemplate>,
} }
#[derive(Serialize)] #[derive(Serialize)]
@@ -70,37 +87,64 @@ pub async fn search_result_handler(
if trimmed_query.is_empty() { if trimmed_query.is_empty() {
(Vec::<SearchResultForTemplate>::new(), String::new()) (Vec::<SearchResultForTemplate>::new(), String::new())
} else { } else {
const TOTAL_LIMIT: usize = 10; // Use retrieval pipeline Search strategy
let (text_results, entity_results) = try_join( let config = RetrievalConfig::for_search(SearchTarget::Both);
TextContent::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT), let result = retrieval_pipeline::pipeline::run_pipeline(
KnowledgeEntity::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT), &state.db,
&state.openai_client,
None, // No embedding provider in HtmlState
trimmed_query,
&user.id,
config,
None, // No reranker for now
) )
.await?; .await?;
let mut combined_results: Vec<SearchResultForTemplate> = let search_result = match result {
Vec::with_capacity(text_results.len() + entity_results.len()); StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
};
for text_result in text_results { let mut combined_results: Vec<SearchResultForTemplate> =
let score = text_result.score; Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
// Add chunk results
for chunk_result in search_result.chunks {
combined_results.push(SearchResultForTemplate { combined_results.push(SearchResultForTemplate {
result_type: "text_content".to_string(), result_type: "text_chunk".to_string(),
score, score: chunk_result.score,
text_content: Some(text_result), text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None, knowledge_entity: None,
}); });
} }
for entity_result in entity_results { // Add entity results
let score = entity_result.score; for entity_result in search_result.entities {
combined_results.push(SearchResultForTemplate { combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(), result_type: "knowledge_entity".to_string(),
score, score: entity_result.score,
text_content: None, text_chunk: None,
knowledge_entity: Some(entity_result), knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
score: entity_result.score,
}),
}); });
} }
// Sort by score descending
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score)); combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
// Limit results
const TOTAL_LIMIT: usize = 10;
combined_results.truncate(TOTAL_LIMIT); combined_results.truncate(TOTAL_LIMIT);
(combined_results, trimmed_query.to_string()) (combined_results, trimmed_query.to_string())

View File

@@ -184,7 +184,8 @@ impl PipelineServices for DefaultPipelineServices {
None => None, None => None,
}; };
let config = retrieval_pipeline::RetrievalConfig::for_ingestion(); let config =
retrieval_pipeline::RetrievalConfig::for_search(retrieval_pipeline::SearchTarget::EntitiesOnly);
match retrieval_pipeline::retrieve_entities( match retrieval_pipeline::retrieve_entities(
&self.db, &self.db,
&self.openai_client, &self.openai_client,
@@ -197,6 +198,16 @@ impl PipelineServices for DefaultPipelineServices {
.await .await
{ {
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities), Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
let chunk_count = search.chunks.len();
let entities = search.entities;
tracing::debug!(
chunk_count,
entity_count = entities.len(),
"ingestion search results returned entities"
);
Ok(entities)
}
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError( Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should return entities".into(), "Ingestion retrieval should return entities".into(),
)), )),

View File

@@ -4,6 +4,7 @@ use common::{
error::AppError, error::AppError,
storage::{ storage::{
db::SurrealDbClient, db::SurrealDbClient,
indexes::rebuild_indexes,
types::{ types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity, ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk, knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
@@ -131,7 +132,7 @@ pub async fn enrich(
}); });
return machine return machine
.enrich() .enrich()
.map_err(|(_, guard)| map_guard_error("enrich", &guard)); .map_err(|(_, guard)| map_guard_error("enrich", &guard));
} }
let content = ctx.text_content()?; let content = ctx.text_content()?;
@@ -173,18 +174,24 @@ pub async fn persist(
let entity_count = entities.len(); let entity_count = entities.len();
let relationship_count = relationships.len(); let relationship_count = relationships.len();
let ((), chunk_count) = tokio::try_join!( debug!("Were storing chunks");
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships), let chunk_count = store_vector_chunks(
store_vector_chunks( ctx.db,
ctx.db, ctx.task_id.as_str(),
ctx.task_id.as_str(), &chunks,
&chunks, &ctx.pipeline_config.tuning,
&ctx.pipeline_config.tuning )
) .await?;
)?;
debug!("We stored chunks");
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
debug!("Stored graph entities");
ctx.db.store_item(text_content).await?; ctx.db.store_item(text_content).await?;
ctx.db.rebuild_indexes().await?;
debug!("stored item");
rebuild_indexes(ctx.db).await?;
debug!( debug!(
task_id = %ctx.task_id, task_id = %ctx.task_id,
@@ -268,17 +275,9 @@ async fn store_vector_chunks(
let chunk_count = chunks.len(); let chunk_count = chunks.len();
let batch_size = tuning.chunk_insert_concurrency.max(1); let batch_size = tuning.chunk_insert_concurrency.max(1);
for embedded in chunks {
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}
for batch in chunks.chunks(batch_size) { for batch in chunks.chunks(batch_size) {
store_chunk_batch(db, batch, tuning).await?; store_chunk_batch(db, batch, tuning, task_id).await?;
} }
Ok(chunk_count) Ok(chunk_count)
@@ -294,14 +293,25 @@ async fn store_chunk_batch(
db: &SurrealDbClient, db: &SurrealDbClient,
batch: &[EmbeddedTextChunk], batch: &[EmbeddedTextChunk],
_tuning: &super::config::IngestionTuning, _tuning: &super::config::IngestionTuning,
task_id: &str,
) -> Result<(), AppError> { ) -> Result<(), AppError> {
if batch.is_empty() { if batch.is_empty() {
return Ok(()); return Ok(());
} }
for embedded in batch { for embedded in batch {
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db) TextChunk::store_with_embedding(
.await?; embedded.chunk.to_owned(),
embedded.embedding.to_owned(),
db,
)
.await?;
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
} }
Ok(()) Ok(())

View File

@@ -2,7 +2,9 @@ use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router}; use axum::{extract::FromRef, Router};
use common::{ use common::{
storage::{ storage::{
db::SurrealDbClient, indexes::ensure_runtime_indexes, store::StorageManager, db::SurrealDbClient,
indexes::ensure_runtime_indexes,
store::StorageManager,
types::system_settings::SystemSettings, types::system_settings::SystemSettings,
}, },
utils::config::get_config, utils::config::get_config,
@@ -112,7 +114,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await .await
.unwrap(), .unwrap(),
); );
let _settings = SystemSettings::get_current(&worker_db) let settings = SystemSettings::get_current(&worker_db)
.await .await
.expect("failed to load system settings"); .expect("failed to load system settings");
@@ -125,9 +127,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create embedding provider for ingestion // Create embedding provider for ingestion
let embedding_provider = Arc::new( let embedding_provider = Arc::new(
common::utils::embedding::EmbeddingProvider::new_fastembed(None) common::utils::embedding::EmbeddingProvider::new_openai(
.await openai_client.clone(),
.expect("failed to create embedding provider"), settings.embedding_model,
settings.embedding_dimensions,
)
.expect("failed to create embedding provider"),
); );
let ingestion_pipeline = Arc::new( let ingestion_pipeline = Arc::new(
IngestionPipeline::new( IngestionPipeline::new(