evals: v3, embeddings at the side

additional indexes
This commit is contained in:
Per Stark
2025-11-26 15:00:55 +01:00
parent 226b2db43a
commit 030f0fc17d
63 changed files with 3859 additions and 1124 deletions

View File

@@ -123,10 +123,6 @@ mod tests {
};
use uuid::Uuid;
/// Returns a placeholder all-zero embedding with 1536 dimensions, matching
/// the dimensionality the test index expects.
fn dummy_embedding() -> Vec<f32> {
    std::iter::repeat(0.0_f32).take(1536).collect()
}
#[tokio::test]
async fn fts_preserves_single_field_score_for_name() {
let namespace = "fts_test_ns";
@@ -146,7 +142,6 @@ mod tests {
"completely unrelated description".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
@@ -194,7 +189,6 @@ mod tests {
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
dummy_embedding(),
user_id.into(),
);
@@ -239,11 +233,10 @@ mod tests {
let chunk = TextChunk::new(
"source_chunk".into(),
"GraphQL documentation reference".into(),
dummy_embedding(),
user_id.into(),
);
db.store_item(chunk.clone())
TextChunk::store_with_embedding(chunk.clone(), vec![0.0; 1536], &db)
.await
.expect("failed to insert chunk");

View File

@@ -171,7 +171,6 @@ mod tests {
let source_id3 = "source789".to_string();
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3];
let user_id = "user123".to_string();
// Entity with source_id1
@@ -181,7 +180,6 @@ mod tests {
"Description 1".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -192,7 +190,6 @@ mod tests {
"Description 2".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -203,7 +200,6 @@ mod tests {
"Description 3".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -214,7 +210,6 @@ mod tests {
"Description 4".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -318,7 +313,6 @@ mod tests {
// Create some test entities
let entity_type = KnowledgeEntityType::Document;
let embedding = vec![0.1, 0.2, 0.3];
let user_id = "user123".to_string();
// Create the central entity we'll query relationships for
@@ -328,7 +322,6 @@ mod tests {
"Central Description".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -339,7 +332,6 @@ mod tests {
"Related Description 1".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -349,7 +341,6 @@ mod tests {
"Related Description 2".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);
@@ -360,7 +351,6 @@ mod tests {
"Unrelated Description".to_string(),
entity_type.clone(),
None,
embedding.clone(),
user_id.clone(),
);

View File

@@ -5,7 +5,6 @@ pub mod graph;
pub mod pipeline;
pub mod reranking;
pub mod scoring;
pub mod vector;
use common::{
error::AppError,
@@ -57,6 +56,7 @@ pub async fn retrieve_entities(
pipeline::run_pipeline(
db_client,
openai_client,
None,
input_text,
user_id,
config,
@@ -110,10 +110,10 @@ mod tests {
db.query(
"BEGIN TRANSACTION;
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3;
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding;
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3;
COMMIT TRANSACTION;",
)
.await
@@ -132,20 +132,18 @@ mod tests {
"Detailed notes about async runtimes".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let chunk = TextChunk::new(
entity.source_id.clone(),
"Tokio uses cooperative scheduling for fairness.".into(),
chunk_embedding_primary(),
user_id.into(),
);
db.store_item(entity.clone())
KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store entity");
db.store_item(chunk.clone())
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk");
@@ -153,6 +151,7 @@ mod tests {
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"Rust concurrency async tasks",
user_id,
@@ -193,7 +192,6 @@ mod tests {
"Explores async runtimes and scheduling strategies.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_high(),
user_id.into(),
);
let neighbor = KnowledgeEntity::new(
@@ -202,34 +200,31 @@ mod tests {
"Details on Tokio's cooperative scheduler.".into(),
KnowledgeEntityType::Document,
None,
entity_embedding_low(),
user_id.into(),
);
db.store_item(primary.clone())
KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db)
.await
.expect("Failed to store primary entity");
db.store_item(neighbor.clone())
KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db)
.await
.expect("Failed to store neighbor entity");
let primary_chunk = TextChunk::new(
primary.source_id.clone(),
"Rust async tasks use Tokio's cooperative scheduler.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let neighbor_chunk = TextChunk::new(
neighbor.source_id.clone(),
"Tokio's scheduler manages task fairness across executors.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(primary_chunk)
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
.await
.expect("Failed to store primary chunk");
db.store_item(neighbor_chunk)
TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db)
.await
.expect("Failed to store neighbor chunk");
@@ -249,6 +244,7 @@ mod tests {
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"Rust concurrency async tasks",
user_id,
@@ -270,6 +266,8 @@ mod tests {
}
}
println!("{:?}", entities);
let neighbor_entry =
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
@@ -293,20 +291,18 @@ mod tests {
let chunk_one = TextChunk::new(
"src_alpha".into(),
"Tokio tasks execute on worker threads managed by the runtime.".into(),
chunk_embedding_primary(),
user_id.into(),
);
let chunk_two = TextChunk::new(
"src_beta".into(),
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
chunk_embedding_secondary(),
user_id.into(),
);
db.store_item(chunk_one.clone())
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db)
.await
.expect("Failed to store chunk one");
db.store_item(chunk_two.clone())
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
.await
.expect("Failed to store chunk two");
@@ -315,6 +311,7 @@ mod tests {
let results = pipeline::run_pipeline_with_embedding(
&db,
&openai_client,
None,
test_embedding(),
"tokio runtime worker behavior",
user_id,

View File

@@ -112,6 +112,7 @@ pub struct PipelineRunOutput<T> {
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
@@ -137,6 +138,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -153,6 +155,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -169,6 +172,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -185,6 +189,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -201,6 +206,7 @@ pub async fn run_pipeline(
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -214,6 +220,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -230,6 +237,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -246,6 +254,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -262,6 +271,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -283,6 +293,7 @@ pub async fn run_pipeline_with_embedding(
pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -296,6 +307,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -316,6 +328,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -340,6 +353,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -353,6 +367,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -373,6 +388,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -419,6 +435,7 @@ async fn execute_strategy<D: StrategyDriver>(
driver: D,
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Option<Vec<f32>>,
input_text: &str,
user_id: &str,
@@ -430,6 +447,7 @@ async fn execute_strategy<D: StrategyDriver>(
Some(embedding) => PipelineContext::with_embedding(
db_client,
openai_client,
embedding_provider,
embedding,
input_text.to_owned(),
user_id.to_owned(),
@@ -439,6 +457,7 @@ async fn execute_strategy<D: StrategyDriver>(
None => PipelineContext::new(
db_client,
openai_client,
embedding_provider,
input_text.to_owned(),
user_id.to_owned(),
config,

View File

@@ -6,7 +6,7 @@ use common::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
utils::{embedding::generate_embedding, embedding::EmbeddingProvider},
};
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
@@ -24,10 +24,6 @@ use crate::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
},
vector::{
find_chunk_snippets_by_vector_similarity_with_embedding,
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
},
RetrievedChunk, RetrievedEntity,
};
@@ -43,6 +39,7 @@ use super::{
pub struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a EmbeddingProvider>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
@@ -51,7 +48,7 @@ pub struct PipelineContext<'a> {
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
pub entity_results: Vec<RetrievedEntity>,
@@ -63,6 +60,7 @@ impl<'a> PipelineContext<'a> {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
input_text: String,
user_id: String,
config: RetrievalConfig,
@@ -71,6 +69,7 @@ impl<'a> PipelineContext<'a> {
Self {
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
@@ -91,6 +90,7 @@ impl<'a> PipelineContext<'a> {
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
@@ -100,6 +100,7 @@ impl<'a> PipelineContext<'a> {
let mut ctx = Self::new(
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
@@ -299,8 +300,16 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding =
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await.map_err(|e| {
AppError::InternalError(format!(
"Failed to generate embedding with provider: {}",
e
))
})?
} else {
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
};
ctx.query_embedding = Some(embedding);
}
@@ -315,19 +324,17 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!(
KnowledgeEntity::vector_search(
tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_vector_similarity_with_embedding(
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
"text_chunk",
&ctx.user_id,
),
find_items_by_fts(
@@ -346,6 +353,15 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
),
)?;
let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results
.into_iter()
.map(|row| Scored::new(row.entity).with_vector_score(row.score))
.collect();
let vector_chunks: Vec<Scored<TextChunk>> = vector_chunk_results
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
@@ -419,14 +435,15 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError>
}
for neighbor in neighbors {
if neighbor.id == seed.id {
let neighbor_id = neighbor.id.clone();
if neighbor_id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.entry(neighbor_id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
@@ -490,8 +507,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
ctx.filtered_entities = filtered_entities;
let query_embedding = ctx.ensure_embedding()?.clone();
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
@@ -507,7 +522,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
ctx.db_client,
&ctx.user_id,
weights,
&query_embedding,
)
.await?;
@@ -579,13 +593,23 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
debug!("Collecting vector chunk candidates for revised strategy");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
let weights = FusionWeights::default();
let mut vector_chunks: Vec<Scored<TextChunk>> = TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
)
.await?;
.await?
.into_iter()
.map(|row| {
let mut scored = Scored::new(row.chunk).with_vector_score(row.score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
scored
})
.collect();
if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats {
@@ -617,7 +641,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
return Ok(());
};
let documents = build_snippet_rerank_documents(
let documents = build_chunk_rerank_documents(
&ctx.revised_chunk_values,
ctx.config.tuning.rerank_keep_top.max(1),
);
@@ -628,11 +652,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_snippet_rerank_results(
&mut ctx.revised_chunk_values,
&ctx.config.tuning,
results,
);
apply_chunk_rerank_results(&mut ctx.revised_chunk_values, &ctx.config.tuning, results);
}
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
Err(err) => warn!(
@@ -649,7 +669,7 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling chunk-only retrieval results");
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
let question_terms = extract_keywords(&ctx.input_text);
rank_snippet_chunks_by_combined_score(
rank_chunks_by_combined_score(
&mut chunk_values,
&question_terms,
ctx.config.tuning.lexical_match_weight,
@@ -662,12 +682,9 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
ctx.chunk_results = chunk_values
.into_iter()
.map(|chunk| {
let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
RetrievedChunk {
chunk: text_chunk,
score: chunk.fused,
}
.map(|chunk| RetrievedChunk {
chunk: chunk.item,
score: chunk.fused,
})
.collect();
@@ -691,7 +708,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let query_embedding = ctx.ensure_embedding()?.clone();
let question_terms = extract_keywords(&ctx.input_text);
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
@@ -704,9 +720,8 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
for chunk_list in chunk_by_source.values_mut() {
chunk_list.sort_by(|a, b| {
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
// No base-table embeddings; order by fused score only.
b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)
});
}
@@ -930,7 +945,6 @@ async fn enrich_chunks_from_entities(
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
query_embedding: &[f32],
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
@@ -964,16 +978,7 @@ async fn enrich_chunks_from_entities(
.copied()
.unwrap_or(0.0);
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
entry.scores.vector = Some(
entry
.scores
.vector
.unwrap_or(0.0)
.max(entity_score * 0.8)
.max(similarity),
);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
@@ -982,24 +987,6 @@ async fn enrich_chunks_from_entities(
Ok(())
}
/// Cosine similarity between two vectors.
///
/// Returns 0.0 for empty inputs, mismatched lengths, or when either vector
/// has zero magnitude, so callers never divide by zero.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    if query.len() != embedding.len() || query.is_empty() {
        return 0.0;
    }
    // Accumulate dot product and both squared norms in a single pass.
    let (dot, norm_q, norm_e) = query.iter().zip(embedding.iter()).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(d, nq, ne), (&q, &e)| (d + q * e, nq + q * q, ne + e * e),
    );
    if norm_q == 0.0 || norm_e == 0.0 {
        0.0
    } else {
        dot / (norm_q.sqrt() * norm_e.sqrt())
    }
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
@@ -1050,10 +1037,7 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
.collect()
}
fn build_snippet_rerank_documents(
chunks: &[Scored<ChunkSnippet>],
max_chunks: usize,
) -> Vec<String> {
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
chunks
.iter()
.take(max_chunks)
@@ -1124,8 +1108,8 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
}
}
fn apply_snippet_rerank_results(
chunks: &mut Vec<Scored<ChunkSnippet>>,
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
@@ -1133,7 +1117,7 @@ fn apply_snippet_rerank_results(
return;
}
let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
let mut remaining: Vec<Option<Scored<TextChunk>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
@@ -1146,7 +1130,7 @@ fn apply_snippet_rerank_results(
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
@@ -1217,32 +1201,6 @@ fn extract_keywords(text: &str) -> Vec<String> {
terms
}
/// Boosts each candidate's fused score with a weighted lexical-overlap bonus
/// against the question terms, then sorts the slice best-first by fused score.
fn rank_snippet_chunks_by_combined_score(
    candidates: &mut [Scored<ChunkSnippet>],
    question_terms: &[String],
    lexical_weight: f32,
) {
    let apply_lexical = lexical_weight > 0.0 && !question_terms.is_empty();
    if apply_lexical {
        for scored in candidates.iter_mut() {
            let overlap = lexical_overlap_score(question_terms, &scored.item.chunk);
            let boosted = clamp_unit(scored.fused + lexical_weight * overlap);
            scored.update_fused(boosted);
        }
    }
    // Descending order; NaN-safe via Ordering::Equal fallback.
    candidates.sort_by(|left, right| {
        right
            .fused
            .partial_cmp(&left.fused)
            .unwrap_or(Ordering::Equal)
    });
}
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
let mut chunk = TextChunk::new(
snippet.source_id.clone(),
snippet.chunk,
Vec::new(),
user_id.to_owned(),
);
chunk.id = snippet.id;
chunk
}
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
if terms.is_empty() {
return 0.0;

View File

@@ -1,218 +0,0 @@
use std::collections::HashMap;
use common::{
error::AppError,
storage::{
db::SurrealDbClient,
types::{file_info::deserialize_flexible_id, StoredObject},
},
utils::embedding::generate_embedding,
};
use serde::Deserialize;
use surrealdb::sql::Thing;
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
/// Retrieves the `take` closest items from `table` for the given input text.
///
/// The input text is first embedded, then the search is delegated to the
/// embedding-based similarity lookup and deserialized into type `T`.
///
/// # Arguments
///
/// * `take` - The number of items to retrieve from the database.
/// * `input_text` - The text to generate embeddings for.
/// * `db_client` - The SurrealDB client to use for querying the database.
/// * `table` - The table to query in the database.
/// * `openai_client` - The OpenAI client to use for generating embeddings.
/// * `user_id` - The user id of the current user; results are scoped to it.
///
/// # Returns
///
/// A vector of scored `T` values closest to the input text, or an `AppError`
/// if embedding generation or the database query fails.
///
/// # Type Parameters
///
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
pub async fn find_items_by_vector_similarity<T>(
    take: usize,
    input_text: &str,
    db_client: &SurrealDbClient,
    table: &str,
    openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // Embed the query text, then reuse the embedding-based search path.
    let query_embedding = generate_embedding(openai_client, input_text, db_client).await?;
    find_items_by_vector_similarity_with_embedding(
        take,
        query_embedding,
        db_client,
        table,
        user_id,
    )
    .await
}
/// Row shape for the KNN distance query: the record id plus the raw vector
/// distance reported by SurrealDB (`None` when the engine omits it).
#[derive(Debug, Deserialize)]
struct DistanceRow {
    // Ids may arrive as Things or plain strings; normalize to String.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // Raw KNN distance; smaller means closer.
    distance: Option<f32>,
}
/// Embedding-based KNN search over `table`, returning up to `take` of the
/// closest records (scoped to `user_id`) wrapped in `Scored` with a vector
/// similarity score.
///
/// Two round-trips are made: a KNN query that yields ids plus raw distances,
/// then a hydration query that fetches the full records by id. Distances are
/// converted to similarities, min-max normalized across the result set when
/// the finite-distance spread is large enough to be meaningful.
pub async fn find_items_by_vector_similarity_with_embedding<T>(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // The embedding is interpolated into the query text as a JSON array;
    // it is serialized f32 data, not user-controlled text.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    // `<|{take},40|>` is SurrealDB's KNN operator (k results, ef = 40).
    let closest_query = format!(
        "SELECT id, vector::distance::knn() AS distance \
        FROM {table} \
        WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
        LIMIT $limit",
        table = table,
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let distance_rows: Vec<DistanceRow> = response.take(0)?;
    if distance_rows.is_empty() {
        return Ok(Vec::new());
    }
    // Second round-trip: hydrate the full records for the matched ids.
    let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
    let thing_ids: Vec<Thing> = ids
        .iter()
        .map(|id| Thing::from((table, id.as_str())))
        .collect();
    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids.clone()))
        .bind(("user_id", user_id.to_owned()))
        .await?;
    let items: Vec<T> = items_response.take(0)?;
    // Index hydrated records by id so results can be emitted in KNN order.
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();
    // Track the finite distance range to decide whether min-max
    // normalization is meaningful for this result set.
    let mut min_distance = f32::MAX;
    let mut max_distance = f32::MIN;
    for row in &distance_rows {
        if let Some(distance) = row.distance {
            if distance.is_finite() {
                if distance < min_distance {
                    min_distance = distance;
                }
                if distance > max_distance {
                    max_distance = distance;
                }
            }
        }
    }
    let normalize = min_distance.is_finite()
        && max_distance.is_finite()
        && (max_distance - min_distance).abs() > f32::EPSILON;
    let mut scored = Vec::with_capacity(distance_rows.len());
    for row in distance_rows {
        // Rows whose record failed to hydrate are silently dropped.
        if let Some(item) = item_map.remove(&row.id) {
            let similarity = row
                .distance
                .map(|distance| {
                    if normalize {
                        let span = max_distance - min_distance;
                        if span.abs() < f32::EPSILON {
                            1.0
                        } else {
                            // Invert: smaller distance -> higher similarity.
                            let normalized = 1.0 - ((distance - min_distance) / span);
                            clamp_unit(normalized)
                        }
                    } else {
                        distance_to_similarity(distance)
                    }
                })
                // Missing distance defaults to 0.0 similarity.
                .unwrap_or_default();
            scored.push(Scored::new(item).with_vector_score(similarity));
        }
    }
    Ok(scored)
}
/// Lightweight projection of a `text_chunk` row returned by snippet search:
/// just the fields needed downstream, without the embedding payload.
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkSnippet {
    // Record id of the underlying text_chunk.
    pub id: String,
    // Id of the source the chunk was extracted from.
    pub source_id: String,
    // The chunk's text content.
    pub chunk: String,
}
/// Raw row shape for the chunk KNN query: snippet fields plus the vector
/// distance reported by SurrealDB (`None` when absent).
#[derive(Debug, Deserialize)]
struct ChunkDistanceRow {
    // Raw KNN distance; smaller means closer.
    distance: Option<f32>,
    // Ids may arrive as Things or plain strings; normalize to String.
    #[serde(deserialize_with = "deserialize_flexible_id")]
    pub id: String,
    pub source_id: String,
    pub chunk: String,
}
/// KNN search over `text_chunk` that returns lightweight `ChunkSnippet`s
/// (id, source id, text) instead of full records, scored by similarity.
///
/// Unlike the generic item search, this selects the needed columns in a
/// single query and converts each raw distance directly to a similarity
/// without min-max normalization across the result set.
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    user_id: &str,
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
    // Serialized as a JSON array literal for the KNN operator; f32 data only.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    // `<|{take},40|>` is SurrealDB's KNN operator (k results, ef = 40).
    let closest_query = format!(
        "SELECT id, source_id, chunk, vector::distance::knn() AS distance \
        FROM text_chunk \
        WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
        LIMIT $limit",
        take = take,
        embedding = embedding_literal
    );
    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;
    let rows: Vec<ChunkDistanceRow> = response.take(0)?;
    let mut scored = Vec::with_capacity(rows.len());
    for row in rows {
        // Missing distance defaults to 0.0 similarity.
        let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
        scored.push(
            Scored::new(ChunkSnippet {
                id: row.id,
                source_id: row.source_id,
                chunk: row.chunk,
            })
            .with_vector_score(similarity),
        );
    }
    Ok(scored)
}