passed wide smoke check

This commit is contained in:
Per Stark
2025-12-10 13:54:08 +01:00
parent 2e2ea0c4ff
commit a5bc72aedf
12 changed files with 403 additions and 235 deletions

4
.gitignore vendored
View File

@@ -10,8 +10,8 @@ result
data
database
eval/cache/
eval/reports/
evaluations/cache/
evaluations/reports/
# Devenv
.devenv*

View File

@@ -97,26 +97,6 @@ impl SurrealDbClient {
Ok(())
}
/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), AppError> {
debug!("Rebuilding indexes");
let rebuild_sql = r#"
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding;
REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding;
"#;
self.client
.query(rebuild_sql)
.await
.map_err(|e| AppError::InternalError(e.to_string()))?;
Ok(())
}
/// Operation to store a object in SurrealDB, requires the struct to implement StoredObject
///
/// # Arguments

View File

@@ -1,19 +1,9 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss,
clippy::redundant_closure_for_method_calls,
clippy::single_match_else,
clippy::uninlined_format_args
)]
use std::time::Duration;
use anyhow::{Context, Result};
use futures::future::try_join_all;
use serde::Deserialize;
use serde_json::Value;
use serde_json::{Map, Value};
use tracing::{debug, info, warn};
use crate::{error::AppError, storage::db::SurrealDbClient};
@@ -28,6 +18,82 @@ struct HnswIndexSpec {
options: &'static str,
}
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}
impl HnswIndexSpec {
fn definition_if_not_exists(&self, dimension: usize) -> String {
format!(
@@ -75,6 +141,20 @@ impl FtsIndexSpec {
field = self.field,
)
}
fn overwrite_definition(&self) -> String {
let analyzer_clause = self
.analyzer
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
.unwrap_or_default();
format!(
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
index = self.index_name,
table = self.table,
field = self.field,
)
}
}
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
@@ -88,6 +168,13 @@ pub async fn ensure_runtime_indexes(
.map_err(|err| AppError::InternalError(err.to_string()))
}
/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}
async fn ensure_runtime_indexes_inner(
db: &SurrealDbClient,
embedding_dimension: usize,
@@ -147,32 +234,68 @@ async fn ensure_runtime_indexes_inner(
Ok(())
}
async fn hnsw_index_state(
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?;
for spec in fts_index_specs() {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping FTS rebuild because index is missing"
);
continue;
}
create_index_with_polling(
db,
spec.overwrite_definition(),
spec.index_name,
spec.table,
Some(spec.table),
)
.await?;
}
let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping HNSW rebuild because index is missing"
);
return Ok(());
}
let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
warn!(
index = spec.index_name,
table = spec.table,
"HNSW index missing dimension; skipping rebuild"
);
return Ok(());
};
create_index_with_polling(
db,
spec.definition_overwrite(dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
});
try_join_all(hnsw_tasks).await.map(|_| ())
}
async fn existing_hnsw_dimension(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
expected_dimension: usize,
) -> Result<HnswIndexState> {
let info_query = format!("INFO FOR TABLE {table};", table = spec.table);
let mut response = db
.client
.query(info_query)
.await
.with_context(|| format!("fetching table info for {}", spec.table))?;
let info: surrealdb::Value = response
.take(0)
.context("failed to take table info response")?;
let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
let Some(indexes) = info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
else {
return Ok(HnswIndexState::Missing);
) -> Result<Option<usize>> {
let Some(indexes) = table_index_definitions(db, spec.table).await? else {
return Ok(None);
};
let Some(definition) = indexes
@@ -180,17 +303,23 @@ async fn hnsw_index_state(
.and_then(|details| details.get("Strand"))
.and_then(|v| v.as_str())
else {
return Ok(HnswIndexState::Missing);
return Ok(None);
};
let Some(current_dimension) = extract_dimension(definition) else {
return Ok(HnswIndexState::Missing);
};
Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
}
if current_dimension == expected_dimension as u64 {
Ok(HnswIndexState::Matches)
} else {
Ok(HnswIndexState::Different(current_dimension))
async fn hnsw_index_state(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
expected_dimension: usize,
) -> Result<HnswIndexState> {
match existing_hnsw_dimension(db, spec).await? {
None => Ok(HnswIndexState::Missing),
Some(current_dimension) if current_dimension == expected_dimension => {
Ok(HnswIndexState::Matches)
}
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
}
}
@@ -492,7 +621,10 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
Ok(rows.first().map_or(0, |r| r.count))
}
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
async fn table_index_definitions(
db: &SurrealDbClient,
table: &str,
) -> Result<Option<Map<String, Value>>> {
let info_query = format!("INFO FOR TABLE {table};");
let mut response = db
.client
@@ -507,94 +639,22 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Re
let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;
let Some(indexes) = info_json
Ok(info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
else {
.cloned())
}
async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
let Some(indexes) = table_index_definitions(db, table).await? else {
return Ok(false);
};
Ok(indexes.contains_key(index_name))
}
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}
const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -17,8 +17,11 @@ pub type DynStore = Arc<dyn ObjectStore>;
/// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)]
pub struct StorageManager {
// Store from objectstore wrapped as dyn
store: DynStore,
// Simple enum to track which kind
backend_kind: StorageKind,
// Where on disk
local_base: Option<PathBuf>,
}

View File

@@ -61,36 +61,34 @@ impl TextChunk {
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = TextChunkEmbedding::new(
&chunk.id,
chunk.source_id.clone(),
embedding,
chunk.user_id.clone(),
);
let chunk_id = chunk.id.clone();
let source_id = chunk.source_id.clone();
let user_id = chunk.user_id.clone();
// Create both records in a single query
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
chunk_table = Self::table_name(),
emb_table = TextChunkEmbedding::table_name(),
);
let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone());
db.client
.query(query)
.bind(("chunk_id", chunk.id.clone()))
// Create both records in a single transaction so we don't orphan embeddings or chunks
let response = db
.client
.query("BEGIN TRANSACTION;")
.query(format!(
"CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;",
chunk_table = Self::table_name(),
))
.query(format!(
"CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;",
emb_table = TextChunkEmbedding::table_name(),
))
.query("COMMIT TRANSACTION;")
.bind(("chunk_id", chunk_id.clone()))
.bind(("chunk", chunk))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;
response.check().map_err(AppError::Database)?;
Ok(())
}
@@ -330,6 +328,7 @@ impl TextChunk {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId;
use uuid::Uuid;
@@ -524,6 +523,46 @@ mod tests {
assert_eq!(embedding.source_id, source_id);
}
#[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() {
let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension)
.await
.expect("ensure runtime indexes");
let chunk = TextChunk::new(
"runtime_src".to_string(),
"runtime chunk body".to_string(),
"runtime_user".to_string(),
);
TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");
let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some(), "chunk should be stored");
let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some(), "embedding should exist");
assert_eq!(
embedding.unwrap().embedding.len(),
embedding_dimension,
"embedding dimension should match runtime index"
);
}
#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
@@ -625,7 +664,7 @@ mod tests {
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(5, "hello", &db, "user")
.await
@@ -651,7 +690,7 @@ mod tests {
user_id.to_string(),
);
db.store_item(chunk.clone()).await.expect("store chunk");
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await
@@ -698,7 +737,7 @@ mod tests {
db.store_item(other_user_chunk)
.await
.expect("store other user chunk");
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");
let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await

View File

@@ -20,8 +20,8 @@ use crate::{
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
OpenAI,
#[default]
OpenAI,
FastEmbed,
Hashed,
}
@@ -276,9 +276,7 @@ fn bucket(token: &str, dimension: usize) -> usize {
let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new();
token.hash(&mut hasher);
usize::try_from(hasher.finish())
.unwrap_or_default()
% safe_dimension
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}
// Backward compatibility function

View File

@@ -215,16 +215,30 @@ impl EvaluationCandidate {
}
}
fn candidates_from_entities(entities: Vec<RetrievedEntity>) -> Vec<EvaluationCandidate> {
entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect()
}
fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidate> {
chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect()
}
pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect(),
StrategyOutput::Chunks(chunks) => chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect(),
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
StrategyOutput::Search(search_result) => {
let mut candidates = candidates_from_entities(search_result.entities);
candidates.extend(candidates_from_chunks(search_result.chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates
}
}
}

View File

@@ -151,6 +151,10 @@ pub async fn get_response_stream(
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks)
}
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);

View File

@@ -1,17 +1,13 @@
use std::{fmt, str::FromStr};
use std::{fmt, str::FromStr, time::Duration};
use axum::{
extract::{Query, State},
response::IntoResponse,
};
use common::storage::types::{
conversation::Conversation,
knowledge_entity::{KnowledgeEntity, KnowledgeEntitySearchResult},
text_content::{TextContent, TextContentSearchResult},
user::User,
};
use futures::future::try_join;
use common::storage::types::{conversation::Conversation, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
use serde::{de, Deserialize, Deserializer, Serialize};
use tokio::time::error::Elapsed;
use crate::{
html_state::HtmlState,
@@ -20,6 +16,7 @@ use crate::{
response_middleware::{HtmlError, TemplateResponse},
},
};
/// Serde deserialization decorator to map empty Strings to None,
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where
@@ -40,6 +37,26 @@ pub struct SearchParams {
query: Option<String>,
}
/// Chunk result for template rendering
#[derive(Serialize)]
struct TextChunkForTemplate {
id: String,
source_id: String,
chunk: String,
score: f32,
}
/// Entity result for template rendering (from pipeline)
#[derive(Serialize)]
struct KnowledgeEntityForTemplate {
id: String,
name: String,
description: String,
entity_type: String,
source_id: String,
score: f32,
}
pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
@@ -50,9 +67,9 @@ pub async fn search_result_handler(
result_type: String,
score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
text_content: Option<TextContentSearchResult>,
text_chunk: Option<TextChunkForTemplate>,
#[serde(skip_serializing_if = "Option::is_none")]
knowledge_entity: Option<KnowledgeEntitySearchResult>,
knowledge_entity: Option<KnowledgeEntityForTemplate>,
}
#[derive(Serialize)]
@@ -70,37 +87,64 @@ pub async fn search_result_handler(
if trimmed_query.is_empty() {
(Vec::<SearchResultForTemplate>::new(), String::new())
} else {
const TOTAL_LIMIT: usize = 10;
let (text_results, entity_results) = try_join(
TextContent::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
KnowledgeEntity::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
// Use retrieval pipeline Search strategy
let config = RetrievalConfig::for_search(SearchTarget::Both);
let result = retrieval_pipeline::pipeline::run_pipeline(
&state.db,
&state.openai_client,
None, // No embedding provider in HtmlState
trimmed_query,
&user.id,
config,
None, // No reranker for now
)
.await?;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(text_results.len() + entity_results.len());
let search_result = match result {
StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
};
for text_result in text_results {
let score = text_result.score;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());
// Add chunk results
for chunk_result in search_result.chunks {
combined_results.push(SearchResultForTemplate {
result_type: "text_content".to_string(),
score,
text_content: Some(text_result),
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}
for entity_result in entity_results {
let score = entity_result.score;
// Add entity results
for entity_result in search_result.entities {
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score,
text_content: None,
knowledge_entity: Some(entity_result),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
score: entity_result.score,
}),
});
}
// Sort by score descending
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));
// Limit results
const TOTAL_LIMIT: usize = 10;
combined_results.truncate(TOTAL_LIMIT);
(combined_results, trimmed_query.to_string())

View File

@@ -184,7 +184,8 @@ impl PipelineServices for DefaultPipelineServices {
None => None,
};
let config = retrieval_pipeline::RetrievalConfig::for_ingestion();
let config =
retrieval_pipeline::RetrievalConfig::for_search(retrieval_pipeline::SearchTarget::EntitiesOnly);
match retrieval_pipeline::retrieve_entities(
&self.db,
&self.openai_client,
@@ -197,6 +198,16 @@ impl PipelineServices for DefaultPipelineServices {
.await
{
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
let chunk_count = search.chunks.len();
let entities = search.entities;
tracing::debug!(
chunk_count,
entity_count = entities.len(),
"ingestion search results returned entities"
);
Ok(entities)
}
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should return entities".into(),
)),

View File

@@ -4,6 +4,7 @@ use common::{
error::AppError,
storage::{
db::SurrealDbClient,
indexes::rebuild_indexes,
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
@@ -131,7 +132,7 @@ pub async fn enrich(
});
return machine
.enrich()
.map_err(|(_, guard)| map_guard_error("enrich", &guard));
.map_err(|(_, guard)| map_guard_error("enrich", &guard));
}
let content = ctx.text_content()?;
@@ -173,18 +174,24 @@ pub async fn persist(
let entity_count = entities.len();
let relationship_count = relationships.len();
let ((), chunk_count) = tokio::try_join!(
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
store_vector_chunks(
ctx.db,
ctx.task_id.as_str(),
&chunks,
&ctx.pipeline_config.tuning
)
)?;
debug!("Were storing chunks");
let chunk_count = store_vector_chunks(
ctx.db,
ctx.task_id.as_str(),
&chunks,
&ctx.pipeline_config.tuning,
)
.await?;
debug!("We stored chunks");
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;
debug!("Stored graph entities");
ctx.db.store_item(text_content).await?;
ctx.db.rebuild_indexes().await?;
debug!("stored item");
rebuild_indexes(ctx.db).await?;
debug!(
task_id = %ctx.task_id,
@@ -268,17 +275,9 @@ async fn store_vector_chunks(
let chunk_count = chunks.len();
let batch_size = tuning.chunk_insert_concurrency.max(1);
for embedded in chunks {
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}
for batch in chunks.chunks(batch_size) {
store_chunk_batch(db, batch, tuning).await?;
store_chunk_batch(db, batch, tuning, task_id).await?;
}
Ok(chunk_count)
@@ -294,14 +293,25 @@ async fn store_chunk_batch(
db: &SurrealDbClient,
batch: &[EmbeddedTextChunk],
_tuning: &super::config::IngestionTuning,
task_id: &str,
) -> Result<(), AppError> {
if batch.is_empty() {
return Ok(());
}
for embedded in batch {
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db)
.await?;
TextChunk::store_with_embedding(
embedded.chunk.to_owned(),
embedded.embedding.to_owned(),
db,
)
.await?;
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}
Ok(())

View File

@@ -2,7 +2,9 @@ use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use common::{
storage::{
db::SurrealDbClient, indexes::ensure_runtime_indexes, store::StorageManager,
db::SurrealDbClient,
indexes::ensure_runtime_indexes,
store::StorageManager,
types::system_settings::SystemSettings,
},
utils::config::get_config,
@@ -112,7 +114,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.unwrap(),
);
let _settings = SystemSettings::get_current(&worker_db)
let settings = SystemSettings::get_current(&worker_db)
.await
.expect("failed to load system settings");
@@ -125,9 +127,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create embedding provider for ingestion
let embedding_provider = Arc::new(
common::utils::embedding::EmbeddingProvider::new_fastembed(None)
.await
.expect("failed to create embedding provider"),
common::utils::embedding::EmbeddingProvider::new_openai(
openai_client.clone(),
settings.embedding_model,
settings.embedding_dimensions,
)
.expect("failed to create embedding provider"),
);
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(