mirror of https://github.com/perstarkse/minne.git (synced 2026-01-17 23:46:39 +01:00)

Commit: passed wide smoke check

.gitignore (vendored): 4 changes
@@ -10,8 +10,8 @@ result
data
database

eval/cache/
eval/reports/
evaluations/cache/
evaluations/reports/

# Devenv
.devenv*

@@ -97,26 +97,6 @@ impl SurrealDbClient {
Ok(())
}

/// Operation to rebuild indexes
pub async fn rebuild_indexes(&self) -> Result<(), AppError> {
debug!("Rebuilding indexes");
let rebuild_sql = r#"
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
REBUILD INDEX IF EXISTS idx_embedding_text_chunk_embedding ON text_chunk_embedding;
REBUILD INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON knowledge_entity_embedding;
"#;

self.client
.query(rebuild_sql)
.await
.map_err(|e| AppError::InternalError(e.to_string()))?;

Ok(())
}
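For orientation, the removed method above boils down to sending one multi-statement REBUILD script. A minimal, std-only sketch of assembling such a script from (index, table) pairs; the helper name and the main() driver are illustrative, not part of the crate:

// Illustrative only: build one SurrealQL script that rebuilds a known set of
// indexes, mirroring the removed rebuild_indexes method above.
fn rebuild_script(indexes: &[(&str, &str)]) -> String {
    indexes
        .iter()
        .map(|(index, table)| format!("REBUILD INDEX IF EXISTS {index} ON {table};"))
        .collect::<Vec<_>>()
        .join("\n")
}

fn main() {
    let script = rebuild_script(&[
        ("text_content_fts_idx", "text_content"),
        ("idx_embedding_text_chunk_embedding", "text_chunk_embedding"),
    ]);
    assert_eq!(script.lines().count(), 2);
    println!("{script}");
}
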
/// Operation to store an object in SurrealDB; requires the struct to implement StoredObject
///
/// # Arguments

@@ -1,19 +1,9 @@
#![allow(
clippy::missing_docs_in_private_items,
clippy::module_name_repetitions,
clippy::items_after_statements,
clippy::arithmetic_side_effects,
clippy::cast_precision_loss,
clippy::redundant_closure_for_method_calls,
clippy::single_match_else,
clippy::uninlined_format_args
)]
use std::time::Duration;

use anyhow::{Context, Result};
use futures::future::try_join_all;
use serde::Deserialize;
use serde_json::Value;
use serde_json::{Map, Value};
use tracing::{debug, info, warn};

use crate::{error::AppError, storage::db::SurrealDbClient};
@@ -28,6 +18,82 @@ struct HnswIndexSpec {
options: &'static str,
}

const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}

const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}

impl HnswIndexSpec {
fn definition_if_not_exists(&self, dimension: usize) -> String {
format!(
@@ -75,6 +141,20 @@ impl FtsIndexSpec {
field = self.field,
)
}

fn overwrite_definition(&self) -> String {
let analyzer_clause = self
.analyzer
.map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
.unwrap_or_default();

format!(
"DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
index = self.index_name,
table = self.table,
field = self.field,
)
}
}
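The generated DDL is easiest to see in isolation. A self-contained, std-only sketch of the spec-plus-overwrite_definition pattern above; FtsSpec and the sample analyzer name my_analyzer are stand-ins, not the crate's types:

// Std-only sketch of spec-driven index DDL; FtsSpec mirrors FtsIndexSpec above.
struct FtsSpec {
    index_name: &'static str,
    table: &'static str,
    field: &'static str,
    analyzer: Option<&'static str>,
    method: &'static str,
}

impl FtsSpec {
    fn overwrite_definition(&self) -> String {
        // Optional " SEARCH ANALYZER <analyzer> <method>" clause, empty when no analyzer is set.
        let analyzer_clause = self
            .analyzer
            .map(|analyzer| format!(" SEARCH ANALYZER {analyzer} {}", self.method))
            .unwrap_or_default();
        format!(
            "DEFINE INDEX OVERWRITE {index} ON TABLE {table} FIELDS {field}{analyzer_clause} CONCURRENTLY;",
            index = self.index_name,
            table = self.table,
            field = self.field,
        )
    }
}

fn main() {
    let spec = FtsSpec {
        index_name: "text_chunk_fts_chunk_idx",
        table: "text_chunk",
        field: "chunk",
        analyzer: Some("my_analyzer"),
        method: "BM25",
    };
    let ddl = spec.overwrite_definition();
    assert!(ddl.contains("SEARCH ANALYZER my_analyzer BM25"));
    assert!(ddl.ends_with("CONCURRENTLY;"));
}
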
/// Build runtime Surreal indexes (FTS + HNSW) using concurrent creation with readiness polling.
@@ -88,6 +168,13 @@ pub async fn ensure_runtime_indexes(
.map_err(|err| AppError::InternalError(err.to_string()))
}

/// Rebuild known FTS and HNSW indexes, skipping any that are not yet defined.
pub async fn rebuild_indexes(db: &SurrealDbClient) -> Result<(), AppError> {
rebuild_indexes_inner(db)
.await
.map_err(|err| AppError::InternalError(err.to_string()))
}

async fn ensure_runtime_indexes_inner(
db: &SurrealDbClient,
embedding_dimension: usize,
@@ -147,32 +234,68 @@ async fn ensure_runtime_indexes_inner(
Ok(())
}

async fn hnsw_index_state(
async fn rebuild_indexes_inner(db: &SurrealDbClient) -> Result<()> {
debug!("Rebuilding indexes with concurrent definitions");
create_fts_analyzer(db).await?;

for spec in fts_index_specs() {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping FTS rebuild because index is missing"
);
continue;
}

create_index_with_polling(
db,
spec.overwrite_definition(),
spec.index_name,
spec.table,
Some(spec.table),
)
.await?;
}

let hnsw_tasks = hnsw_index_specs().into_iter().map(|spec| async move {
if !index_exists(db, spec.table, spec.index_name).await? {
debug!(
index = spec.index_name,
table = spec.table,
"Skipping HNSW rebuild because index is missing"
);
return Ok(());
}

let Some(dimension) = existing_hnsw_dimension(db, &spec).await? else {
warn!(
index = spec.index_name,
table = spec.table,
"HNSW index missing dimension; skipping rebuild"
);
return Ok(());
};

create_index_with_polling(
db,
spec.definition_overwrite(dimension),
spec.index_name,
spec.table,
Some(spec.table),
)
.await
});

try_join_all(hnsw_tasks).await.map(|_| ())
}
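The rebuild path above only rewrites indexes that already exist. A std-only sketch of that planning step, with the set of existing index names supplied by hand instead of being read from INFO FOR TABLE; plan_rebuild is a hypothetical helper:

use std::collections::HashSet;

// Only specs whose index is already defined get an OVERWRITE statement;
// the others are skipped, mirroring the "index is missing" branches above.
fn plan_rebuild(specs: &[(&str, &str)], existing: &HashSet<&str>) -> Vec<String> {
    specs
        .iter()
        .filter(|(index, _)| existing.contains(index))
        .map(|(index, table)| {
            format!("DEFINE INDEX OVERWRITE {index} ON TABLE {table} CONCURRENTLY;")
        })
        .collect()
}

fn main() {
    let existing: HashSet<&str> = ["text_chunk_fts_chunk_idx"].into_iter().collect();
    let plan = plan_rebuild(
        &[
            ("text_chunk_fts_chunk_idx", "text_chunk"),
            ("text_content_fts_idx", "text_content"), // not defined yet, so skipped
        ],
        &existing,
    );
    assert_eq!(plan.len(), 1);
    assert!(plan[0].contains("text_chunk_fts_chunk_idx"));
}
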
async fn existing_hnsw_dimension(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
expected_dimension: usize,
) -> Result<HnswIndexState> {
let info_query = format!("INFO FOR TABLE {table};", table = spec.table);
let mut response = db
.client
.query(info_query)
.await
.with_context(|| format!("fetching table info for {}", spec.table))?;

let info: surrealdb::Value = response
.take(0)
.context("failed to take table info response")?;

let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;

let Some(indexes) = info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
else {
return Ok(HnswIndexState::Missing);
) -> Result<Option<usize>> {
let Some(indexes) = table_index_definitions(db, spec.table).await? else {
return Ok(None);
};

let Some(definition) = indexes
@@ -180,17 +303,23 @@ async fn hnsw_index_state(
.and_then(|details| details.get("Strand"))
.and_then(|v| v.as_str())
else {
return Ok(HnswIndexState::Missing);
return Ok(None);
};

let Some(current_dimension) = extract_dimension(definition) else {
return Ok(HnswIndexState::Missing);
};
Ok(extract_dimension(definition).and_then(|d| usize::try_from(d).ok()))
}

if current_dimension == expected_dimension as u64 {
Ok(HnswIndexState::Matches)
} else {
Ok(HnswIndexState::Different(current_dimension))
async fn hnsw_index_state(
db: &SurrealDbClient,
spec: &HnswIndexSpec,
expected_dimension: usize,
) -> Result<HnswIndexState> {
match existing_hnsw_dimension(db, spec).await? {
None => Ok(HnswIndexState::Missing),
Some(current_dimension) if current_dimension == expected_dimension => {
Ok(HnswIndexState::Matches)
}
Some(current_dimension) => Ok(HnswIndexState::Different(current_dimension as u64)),
}
}
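The refactor splits the dimension lookup from the classification. A std-only sketch of the classification half; IndexState and classify are stand-ins for HnswIndexState and hnsw_index_state:

// An existing HNSW index either matches the expected embedding dimension,
// differs from it, or is missing entirely.
#[derive(Debug, PartialEq)]
enum IndexState {
    Missing,
    Matches,
    Different(u64),
}

fn classify(existing: Option<usize>, expected: usize) -> IndexState {
    match existing {
        None => IndexState::Missing,
        Some(current) if current == expected => IndexState::Matches,
        Some(current) => IndexState::Different(current as u64),
    }
}

fn main() {
    assert_eq!(classify(None, 1536), IndexState::Missing);
    assert_eq!(classify(Some(1536), 1536), IndexState::Matches);
    assert_eq!(classify(Some(768), 1536), IndexState::Different(768));
}
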
@@ -492,7 +621,10 @@ async fn count_table_rows(db: &SurrealDbClient, table: &str) -> Result<u64> {
Ok(rows.first().map_or(0, |r| r.count))
}

async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
async fn table_index_definitions(
db: &SurrealDbClient,
table: &str,
) -> Result<Option<Map<String, Value>>> {
let info_query = format!("INFO FOR TABLE {table};");
let mut response = db
.client
@@ -507,94 +639,22 @@ async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
let info_json: Value =
serde_json::to_value(info).context("serializing table info to JSON for parsing")?;

let Some(indexes) = info_json
Ok(info_json
.get("Object")
.and_then(|o| o.get("indexes"))
.and_then(|i| i.get("Object"))
.and_then(|i| i.as_object())
else {
.cloned())
}

async fn index_exists(db: &SurrealDbClient, table: &str, index_name: &str) -> Result<bool> {
let Some(indexes) = table_index_definitions(db, table).await? else {
return Ok(false);
};

Ok(indexes.contains_key(index_name))
}
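index_exists now rides on the shared table_index_definitions helper. A sketch of the same JSON drill-down against a hand-written value shaped like the serialized INFO FOR TABLE output shown in this diff; it requires the serde_json crate, and the sample value is assumed, not captured from SurrealDB:

use serde_json::{json, Value};

// Walk Object -> indexes -> Object and check whether an index name is present.
fn index_exists(info: &Value, index_name: &str) -> bool {
    info.get("Object")
        .and_then(|o| o.get("indexes"))
        .and_then(|i| i.get("Object"))
        .and_then(|i| i.as_object())
        .map_or(false, |indexes| indexes.contains_key(index_name))
}

fn main() {
    let info = json!({
        "Object": {
            "indexes": {
                "Object": { "text_chunk_fts_chunk_idx": { "Strand": "DEFINE INDEX ..." } }
            }
        }
    });
    assert!(index_exists(&info, "text_chunk_fts_chunk_idx"));
    assert!(!index_exists(&info, "missing_idx"));
}
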
const fn hnsw_index_specs() -> [HnswIndexSpec; 2] {
[
HnswIndexSpec {
index_name: "idx_embedding_text_chunk_embedding",
table: "text_chunk_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
HnswIndexSpec {
index_name: "idx_embedding_knowledge_entity_embedding",
table: "knowledge_entity_embedding",
options: "DIST COSINE TYPE F32 EFC 100 M 8 CONCURRENTLY",
},
]
}

const fn fts_index_specs() -> [FtsIndexSpec; 8] {
[
FtsIndexSpec {
index_name: "text_content_fts_idx",
table: "text_content",
field: "text",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_context_fts_idx",
table: "text_content",
field: "context",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_file_name_fts_idx",
table: "text_content",
field: "file_info.file_name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_fts_idx",
table: "text_content",
field: "url_info.url",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_content_url_title_fts_idx",
table: "text_content",
field: "url_info.title",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_name_idx",
table: "knowledge_entity",
field: "name",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "knowledge_entity_fts_description_idx",
table: "knowledge_entity",
field: "description",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
FtsIndexSpec {
index_name: "text_chunk_fts_chunk_idx",
table: "text_chunk",
field: "chunk",
analyzer: Some(FTS_ANALYZER_NAME),
method: "BM25",
},
]
}

#[cfg(test)]
mod tests {
use super::*;

@@ -17,8 +17,11 @@ pub type DynStore = Arc<dyn ObjectStore>;
/// Storage manager with persistent state and proper lifecycle management.
#[derive(Clone)]
pub struct StorageManager {
// Store from objectstore wrapped as dyn
store: DynStore,
// Simple enum to track which kind
backend_kind: StorageKind,
// Where on disk
local_base: Option<PathBuf>,
}

@@ -61,36 +61,34 @@ impl TextChunk {
embedding: Vec<f32>,
db: &SurrealDbClient,
) -> Result<(), AppError> {
let emb = TextChunkEmbedding::new(
&chunk.id,
chunk.source_id.clone(),
embedding,
chunk.user_id.clone(),
);
let chunk_id = chunk.id.clone();
let source_id = chunk.source_id.clone();
let user_id = chunk.user_id.clone();

// Create both records in a single query
let query = format!(
"
BEGIN TRANSACTION;
CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;
CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;
COMMIT TRANSACTION;
",
chunk_table = Self::table_name(),
emb_table = TextChunkEmbedding::table_name(),
);
let emb = TextChunkEmbedding::new(&chunk_id, source_id.clone(), embedding, user_id.clone());

db.client
.query(query)
.bind(("chunk_id", chunk.id.clone()))
// Create both records in a single transaction so we don't orphan embeddings or chunks
let response = db
.client
.query("BEGIN TRANSACTION;")
.query(format!(
"CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;",
chunk_table = Self::table_name(),
))
.query(format!(
"CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;",
emb_table = TextChunkEmbedding::table_name(),
))
.query("COMMIT TRANSACTION;")
.bind(("chunk_id", chunk_id.clone()))
.bind(("chunk", chunk))
.bind(("emb_id", emb.id.clone()))
.bind(("emb", emb))
.await
.map_err(AppError::Database)?
.check()
.map_err(AppError::Database)?;

response.check().map_err(AppError::Database)?;

Ok(())
}
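The chained .query() calls above send what amounts to one atomic SurrealQL script. A std-only sketch of that script's shape; in the real code the record contents stay bound as $chunk and $emb rather than being interpolated, and the helper name here is illustrative:

// Illustrative only: the SurrealQL that the chained queries add up to,
// parameterized by the two table names.
fn chunk_with_embedding_script(chunk_table: &str, emb_table: &str) -> String {
    format!(
        "BEGIN TRANSACTION;\n\
         CREATE type::thing('{chunk_table}', $chunk_id) CONTENT $chunk;\n\
         CREATE type::thing('{emb_table}', $emb_id) CONTENT $emb;\n\
         COMMIT TRANSACTION;"
    )
}

fn main() {
    let script = chunk_with_embedding_script("text_chunk", "text_chunk_embedding");
    assert!(script.starts_with("BEGIN TRANSACTION;"));
    assert!(script.ends_with("COMMIT TRANSACTION;"));
    println!("{script}");
}
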
@@ -330,6 +328,7 @@ impl TextChunk {
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::indexes::{ensure_runtime_indexes, rebuild_indexes};
use crate::storage::types::text_chunk_embedding::TextChunkEmbedding;
use surrealdb::RecordId;
use uuid::Uuid;
@@ -524,6 +523,46 @@ mod tests {
assert_eq!(embedding.source_id, source_id);
}

#[tokio::test]
async fn test_store_with_embedding_with_runtime_indexes() {
let namespace = "test_ns_runtime";
let database = &Uuid::new_v4().to_string();
let db = SurrealDbClient::memory(namespace, database)
.await
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");

// Ensure runtime indexes are built with the expected dimension.
let embedding_dimension = 3usize;
ensure_runtime_indexes(&db, embedding_dimension)
.await
.expect("ensure runtime indexes");

let chunk = TextChunk::new(
"runtime_src".to_string(),
"runtime chunk body".to_string(),
"runtime_user".to_string(),
);

TextChunk::store_with_embedding(chunk.clone(), vec![0.1, 0.2, 0.3], &db)
.await
.expect("store with embedding");

let stored_chunk: Option<TextChunk> = db.get_item(&chunk.id).await.unwrap();
assert!(stored_chunk.is_some(), "chunk should be stored");

let rid = RecordId::from_table_key(TextChunk::table_name(), &chunk.id);
let embedding = TextChunkEmbedding::get_by_chunk_id(&rid, &db)
.await
.expect("get embedding");
assert!(embedding.is_some(), "embedding should exist");
assert_eq!(
embedding.unwrap().embedding.len(),
embedding_dimension,
"embedding dimension should match runtime index"
);
}

#[tokio::test]
async fn test_vector_search_returns_empty_when_no_embeddings() {
let namespace = "test_ns";
@@ -625,7 +664,7 @@ mod tests {
.expect("Failed to start in-memory surrealdb");
db.apply_migrations().await.expect("migrations");
ensure_chunk_fts_index(&db).await;
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");

let results = TextChunk::fts_search(5, "hello", &db, "user")
.await
@@ -651,7 +690,7 @@ mod tests {
user_id.to_string(),
);
db.store_item(chunk.clone()).await.expect("store chunk");
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");

let results = TextChunk::fts_search(3, "rust", &db, user_id)
.await
@@ -698,7 +737,7 @@ mod tests {
db.store_item(other_user_chunk)
.await
.expect("store other user chunk");
db.rebuild_indexes().await.expect("rebuild indexes");
rebuild_indexes(&db).await.expect("rebuild indexes");

let results = TextChunk::fts_search(3, "apple", &db, user_id)
.await

@@ -20,8 +20,8 @@ use crate::{
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EmbeddingBackend {
OpenAI,
#[default]
OpenAI,
FastEmbed,
Hashed,
}
@@ -276,9 +276,7 @@ fn bucket(token: &str, dimension: usize) -> usize {
let safe_dimension = dimension.max(1);
let mut hasher = DefaultHasher::new();
token.hash(&mut hasher);
usize::try_from(hasher.finish())
.unwrap_or_default()
% safe_dimension
usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}
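bucket() is the core of the hashed embedding backend. A std-only sketch of how it can back a whole toy embedding: hash each token into one of `dimension` buckets and count occurrences; hashed_embedding is a hypothetical wrapper, not the crate's API:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Same bucketing as above: hash the token, then reduce modulo the dimension.
fn bucket(token: &str, dimension: usize) -> usize {
    let safe_dimension = dimension.max(1);
    let mut hasher = DefaultHasher::new();
    token.hash(&mut hasher);
    usize::try_from(hasher.finish()).unwrap_or_default() % safe_dimension
}

// Toy embedding: one count per bucket, so repeated tokens raise their bucket.
fn hashed_embedding(text: &str, dimension: usize) -> Vec<f32> {
    let mut embedding = vec![0.0_f32; dimension.max(1)];
    for token in text.split_whitespace() {
        embedding[bucket(token, dimension)] += 1.0;
    }
    embedding
}

fn main() {
    let embedding = hashed_embedding("store the same token token", 8);
    assert_eq!(embedding.len(), 8);
    assert_eq!(embedding.iter().sum::<f32>(), 5.0); // five tokens counted in total
}
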
// Backward compatibility function

@@ -215,16 +215,30 @@ impl EvaluationCandidate {
}
}

fn candidates_from_entities(entities: Vec<RetrievedEntity>) -> Vec<EvaluationCandidate> {
entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect()
}

fn candidates_from_chunks(chunks: Vec<RetrievedChunk>) -> Vec<EvaluationCandidate> {
chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect()
}

pub fn adapt_strategy_output(output: StrategyOutput) -> Vec<EvaluationCandidate> {
match output {
StrategyOutput::Entities(entities) => entities
.into_iter()
.map(EvaluationCandidate::from_entity)
.collect(),
StrategyOutput::Chunks(chunks) => chunks
.into_iter()
.map(EvaluationCandidate::from_chunk)
.collect(),
StrategyOutput::Entities(entities) => candidates_from_entities(entities),
StrategyOutput::Chunks(chunks) => candidates_from_chunks(chunks),
StrategyOutput::Search(search_result) => {
let mut candidates = candidates_from_entities(search_result.entities);
candidates.extend(candidates_from_chunks(search_result.chunks));
candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
candidates
}
}
}
}
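The new Search arm merges both candidate sources and re-ranks them. A std-only sketch of that merge-and-sort step; Candidate stands in for EvaluationCandidate:

// Merge entity- and chunk-derived candidates, then order by descending score.
#[derive(Debug)]
struct Candidate {
    id: &'static str,
    score: f32,
}

fn merge_ranked(entities: Vec<Candidate>, chunks: Vec<Candidate>) -> Vec<Candidate> {
    let mut candidates = entities;
    candidates.extend(chunks);
    // total_cmp gives a total order even for NaN scores, unlike partial_cmp.
    candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
    candidates
}

fn main() {
    let merged = merge_ranked(
        vec![Candidate { id: "entity", score: 0.4 }],
        vec![Candidate { id: "chunk", score: 0.9 }],
    );
    assert_eq!(merged[0].id, "chunk");
}
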
@@ -151,6 +151,10 @@ pub async fn get_response_stream(
retrieval_pipeline::StrategyOutput::Entities(entities) => {
retrieved_entities_to_json(&entities)
}
retrieval_pipeline::StrategyOutput::Search(search_result) => {
// For chat, use chunks from the search result
chunks_to_chat_context(&search_result.chunks)
}
};
let formatted_user_message =
create_user_message_with_history(&context_json, &history, &user_message.content);

@@ -1,17 +1,13 @@
use std::{fmt, str::FromStr};
use std::{fmt, str::FromStr, time::Duration};

use axum::{
extract::{Query, State},
response::IntoResponse,
};
use common::storage::types::{
conversation::Conversation,
knowledge_entity::{KnowledgeEntity, KnowledgeEntitySearchResult},
text_content::{TextContent, TextContentSearchResult},
user::User,
};
use futures::future::try_join;
use common::storage::types::{conversation::Conversation, user::User};
use retrieval_pipeline::{RetrievalConfig, SearchResult, SearchTarget, StrategyOutput};
use serde::{de, Deserialize, Deserializer, Serialize};
use tokio::time::error::Elapsed;

use crate::{
html_state::HtmlState,
@@ -20,6 +16,7 @@ use crate::{
response_middleware::{HtmlError, TemplateResponse},
},
};

/// Serde deserialization decorator to map empty Strings to None,
fn empty_string_as_none<'de, D, T>(de: D) -> Result<Option<T>, D::Error>
where
@@ -40,6 +37,26 @@ pub struct SearchParams {
query: Option<String>,
}

/// Chunk result for template rendering
#[derive(Serialize)]
struct TextChunkForTemplate {
id: String,
source_id: String,
chunk: String,
score: f32,
}

/// Entity result for template rendering (from pipeline)
#[derive(Serialize)]
struct KnowledgeEntityForTemplate {
id: String,
name: String,
description: String,
entity_type: String,
source_id: String,
score: f32,
}

pub async fn search_result_handler(
State(state): State<HtmlState>,
Query(params): Query<SearchParams>,
@@ -50,9 +67,9 @@ pub async fn search_result_handler(
result_type: String,
score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
text_content: Option<TextContentSearchResult>,
text_chunk: Option<TextChunkForTemplate>,
#[serde(skip_serializing_if = "Option::is_none")]
knowledge_entity: Option<KnowledgeEntitySearchResult>,
knowledge_entity: Option<KnowledgeEntityForTemplate>,
}

#[derive(Serialize)]
@@ -70,37 +87,64 @@ pub async fn search_result_handler(
if trimmed_query.is_empty() {
(Vec::<SearchResultForTemplate>::new(), String::new())
} else {
const TOTAL_LIMIT: usize = 10;
let (text_results, entity_results) = try_join(
TextContent::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
KnowledgeEntity::search(&state.db, trimmed_query, &user.id, TOTAL_LIMIT),
// Use retrieval pipeline Search strategy
let config = RetrievalConfig::for_search(SearchTarget::Both);
let result = retrieval_pipeline::pipeline::run_pipeline(
&state.db,
&state.openai_client,
None, // No embedding provider in HtmlState
trimmed_query,
&user.id,
config,
None, // No reranker for now
)
.await?;

let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(text_results.len() + entity_results.len());
let search_result = match result {
StrategyOutput::Search(sr) => sr,
_ => SearchResult::new(vec![], vec![]),
};

for text_result in text_results {
let score = text_result.score;
let mut combined_results: Vec<SearchResultForTemplate> =
Vec::with_capacity(search_result.chunks.len() + search_result.entities.len());

// Add chunk results
for chunk_result in search_result.chunks {
combined_results.push(SearchResultForTemplate {
result_type: "text_content".to_string(),
score,
text_content: Some(text_result),
result_type: "text_chunk".to_string(),
score: chunk_result.score,
text_chunk: Some(TextChunkForTemplate {
id: chunk_result.chunk.id,
source_id: chunk_result.chunk.source_id,
chunk: chunk_result.chunk.chunk,
score: chunk_result.score,
}),
knowledge_entity: None,
});
}

for entity_result in entity_results {
let score = entity_result.score;
// Add entity results
for entity_result in search_result.entities {
combined_results.push(SearchResultForTemplate {
result_type: "knowledge_entity".to_string(),
score,
text_content: None,
knowledge_entity: Some(entity_result),
score: entity_result.score,
text_chunk: None,
knowledge_entity: Some(KnowledgeEntityForTemplate {
id: entity_result.entity.id,
name: entity_result.entity.name,
description: entity_result.entity.description,
entity_type: format!("{:?}", entity_result.entity.entity_type),
source_id: entity_result.entity.source_id,
score: entity_result.score,
}),
});
}

// Sort by score descending
combined_results.sort_by(|a, b| b.score.total_cmp(&a.score));

// Limit results
const TOTAL_LIMIT: usize = 10;
combined_results.truncate(TOTAL_LIMIT);

(combined_results, trimmed_query.to_string())

@@ -184,7 +184,8 @@ impl PipelineServices for DefaultPipelineServices {
None => None,
};

let config = retrieval_pipeline::RetrievalConfig::for_ingestion();
let config =
retrieval_pipeline::RetrievalConfig::for_search(retrieval_pipeline::SearchTarget::EntitiesOnly);
match retrieval_pipeline::retrieve_entities(
&self.db,
&self.openai_client,
@@ -197,6 +198,16 @@ impl PipelineServices for DefaultPipelineServices {
.await
{
Ok(retrieval_pipeline::StrategyOutput::Entities(entities)) => Ok(entities),
Ok(retrieval_pipeline::StrategyOutput::Search(search)) => {
let chunk_count = search.chunks.len();
let entities = search.entities;
tracing::debug!(
chunk_count,
entity_count = entities.len(),
"ingestion search results returned entities"
);
Ok(entities)
}
Ok(retrieval_pipeline::StrategyOutput::Chunks(_)) => Err(AppError::InternalError(
"Ingestion retrieval should return entities".into(),
)),

@@ -4,6 +4,7 @@ use common::{
error::AppError,
storage::{
db::SurrealDbClient,
indexes::rebuild_indexes,
types::{
ingestion_payload::IngestionPayload, knowledge_entity::KnowledgeEntity,
knowledge_relationship::KnowledgeRelationship, text_chunk::TextChunk,
@@ -131,7 +132,7 @@ pub async fn enrich(
});
return machine
.enrich()
.map_err(|(_, guard)| map_guard_error("enrich", &guard));
.map_err(|(_, guard)| map_guard_error("enrich", &guard));
}

let content = ctx.text_content()?;
@@ -173,18 +174,24 @@ pub async fn persist(
let entity_count = entities.len();
let relationship_count = relationships.len();

let ((), chunk_count) = tokio::try_join!(
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships),
store_vector_chunks(
ctx.db,
ctx.task_id.as_str(),
&chunks,
&ctx.pipeline_config.tuning
)
)?;
debug!("We're storing chunks");
let chunk_count = store_vector_chunks(
ctx.db,
ctx.task_id.as_str(),
&chunks,
&ctx.pipeline_config.tuning,
)
.await?;

debug!("We stored chunks");
store_graph_entities(ctx.db, &ctx.pipeline_config.tuning, entities, relationships).await?;

debug!("Stored graph entities");

ctx.db.store_item(text_content).await?;
ctx.db.rebuild_indexes().await?;

debug!("stored item");
rebuild_indexes(ctx.db).await?;

debug!(
task_id = %ctx.task_id,
@@ -268,17 +275,9 @@ async fn store_vector_chunks(
let chunk_count = chunks.len();

let batch_size = tuning.chunk_insert_concurrency.max(1);
for embedded in chunks {
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}

for batch in chunks.chunks(batch_size) {
store_chunk_batch(db, batch, tuning).await?;
store_chunk_batch(db, batch, tuning, task_id).await?;
}

Ok(chunk_count)
@@ -294,14 +293,25 @@ async fn store_chunk_batch(
db: &SurrealDbClient,
batch: &[EmbeddedTextChunk],
_tuning: &super::config::IngestionTuning,
task_id: &str,
) -> Result<(), AppError> {
if batch.is_empty() {
return Ok(());
}

for embedded in batch {
TextChunk::store_with_embedding(embedded.chunk.clone(), embedded.embedding.clone(), db)
.await?;
TextChunk::store_with_embedding(
embedded.chunk.to_owned(),
embedded.embedding.to_owned(),
db,
)
.await?;
debug!(
task_id = %task_id,
chunk_id = %embedded.chunk.id,
chunk_len = embedded.chunk.chunk.chars().count(),
"chunk persisted"
);
}

Ok(())

@@ -2,7 +2,9 @@ use api_router::{api_routes_v1, api_state::ApiState};
use axum::{extract::FromRef, Router};
use common::{
storage::{
db::SurrealDbClient, indexes::ensure_runtime_indexes, store::StorageManager,
db::SurrealDbClient,
indexes::ensure_runtime_indexes,
store::StorageManager,
types::system_settings::SystemSettings,
},
utils::config::get_config,
@@ -112,7 +114,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.unwrap(),
);
let _settings = SystemSettings::get_current(&worker_db)
let settings = SystemSettings::get_current(&worker_db)
.await
.expect("failed to load system settings");

@@ -125,9 +127,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Create embedding provider for ingestion
let embedding_provider = Arc::new(
common::utils::embedding::EmbeddingProvider::new_fastembed(None)
.await
.expect("failed to create embedding provider"),
common::utils::embedding::EmbeddingProvider::new_openai(
openai_client.clone(),
settings.embedding_model,
settings.embedding_dimensions,
)
.expect("failed to create embedding provider"),
);
let ingestion_pipeline = Arc::new(
IngestionPipeline::new(