evals: v3, embeddings at the side

additional indexes
This commit is contained in:
Per Stark
2025-11-26 15:00:55 +01:00
parent 226b2db43a
commit 030f0fc17d
63 changed files with 3859 additions and 1124 deletions
+19
View File
@@ -112,6 +112,7 @@ pub struct PipelineRunOutput<T> {
pub async fn run_pipeline(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
input_text: &str,
user_id: &str,
config: RetrievalConfig,
@@ -137,6 +138,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -153,6 +155,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -169,6 +172,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -185,6 +189,7 @@ pub async fn run_pipeline(
driver,
db_client,
openai_client,
embedding_provider,
None,
input_text,
user_id,
@@ -201,6 +206,7 @@ pub async fn run_pipeline(
pub async fn run_pipeline_with_embedding(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -214,6 +220,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -230,6 +237,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -246,6 +254,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -262,6 +271,7 @@ pub async fn run_pipeline_with_embedding(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -283,6 +293,7 @@ pub async fn run_pipeline_with_embedding(
pub async fn run_pipeline_with_embedding_with_metrics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -296,6 +307,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -316,6 +328,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -340,6 +353,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
pub async fn run_pipeline_with_embedding_with_diagnostics(
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: &str,
user_id: &str,
@@ -353,6 +367,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -373,6 +388,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
driver,
db_client,
openai_client,
embedding_provider,
Some(query_embedding),
input_text,
user_id,
@@ -419,6 +435,7 @@ async fn execute_strategy<D: StrategyDriver>(
driver: D,
db_client: &SurrealDbClient,
openai_client: &Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
query_embedding: Option<Vec<f32>>,
input_text: &str,
user_id: &str,
@@ -430,6 +447,7 @@ async fn execute_strategy<D: StrategyDriver>(
Some(embedding) => PipelineContext::with_embedding(
db_client,
openai_client,
embedding_provider,
embedding,
input_text.to_owned(),
user_id.to_owned(),
@@ -439,6 +457,7 @@ async fn execute_strategy<D: StrategyDriver>(
None => PipelineContext::new(
db_client,
openai_client,
embedding_provider,
input_text.to_owned(),
user_id.to_owned(),
config,
+58 -100
View File
@@ -6,7 +6,7 @@ use common::{
db::SurrealDbClient,
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
},
utils::embedding::generate_embedding,
utils::{embedding::generate_embedding, embedding::EmbeddingProvider},
};
use fastembed::RerankResult;
use futures::{stream::FuturesUnordered, StreamExt};
@@ -24,10 +24,6 @@ use crate::{
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
FusionWeights, Scored,
},
vector::{
find_chunk_snippets_by_vector_similarity_with_embedding,
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
},
RetrievedChunk, RetrievedEntity,
};
@@ -43,6 +39,7 @@ use super::{
pub struct PipelineContext<'a> {
pub db_client: &'a SurrealDbClient,
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
pub embedding_provider: Option<&'a EmbeddingProvider>,
pub input_text: String,
pub user_id: String,
pub config: RetrievalConfig,
@@ -51,7 +48,7 @@ pub struct PipelineContext<'a> {
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
pub chunk_values: Vec<Scored<TextChunk>>,
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
pub revised_chunk_values: Vec<Scored<TextChunk>>,
pub reranker: Option<RerankerLease>,
pub diagnostics: Option<PipelineDiagnostics>,
pub entity_results: Vec<RetrievedEntity>,
@@ -63,6 +60,7 @@ impl<'a> PipelineContext<'a> {
pub fn new(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
input_text: String,
user_id: String,
config: RetrievalConfig,
@@ -71,6 +69,7 @@ impl<'a> PipelineContext<'a> {
Self {
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
@@ -91,6 +90,7 @@ impl<'a> PipelineContext<'a> {
pub fn with_embedding(
db_client: &'a SurrealDbClient,
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
embedding_provider: Option<&'a EmbeddingProvider>,
query_embedding: Vec<f32>,
input_text: String,
user_id: String,
@@ -100,6 +100,7 @@ impl<'a> PipelineContext<'a> {
let mut ctx = Self::new(
db_client,
openai_client,
embedding_provider,
input_text,
user_id,
config,
@@ -299,8 +300,16 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Reusing cached query embedding for hybrid retrieval");
} else {
debug!("Generating query embedding for hybrid retrieval");
let embedding =
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
let embedding = if let Some(provider) = ctx.embedding_provider {
provider.embed(&ctx.input_text).await.map_err(|e| {
AppError::InternalError(format!(
"Failed to generate embedding with provider: {}",
e
))
})?
} else {
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
};
ctx.query_embedding = Some(embedding);
}
@@ -315,19 +324,17 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
let weights = FusionWeights::default();
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
find_items_by_vector_similarity_with_embedding(
let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!(
KnowledgeEntity::vector_search(
tuning.entity_vector_take,
embedding.clone(),
ctx.db_client,
"knowledge_entity",
&ctx.user_id,
),
find_items_by_vector_similarity_with_embedding(
TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
"text_chunk",
&ctx.user_id,
),
find_items_by_fts(
@@ -346,6 +353,15 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
),
)?;
let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results
.into_iter()
.map(|row| Scored::new(row.entity).with_vector_score(row.score))
.collect();
let vector_chunks: Vec<Scored<TextChunk>> = vector_chunk_results
.into_iter()
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
.collect();
debug!(
vector_entities = vector_entities.len(),
vector_chunks = vector_chunks.len(),
@@ -419,14 +435,15 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError>
}
for neighbor in neighbors {
if neighbor.id == seed.id {
let neighbor_id = neighbor.id.clone();
if neighbor_id == seed.id {
continue;
}
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
let entry = ctx
.entity_candidates
.entry(neighbor.id.clone())
.entry(neighbor_id.clone())
.or_insert_with(|| Scored::new(neighbor.clone()));
entry.item = neighbor;
@@ -490,8 +507,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
ctx.filtered_entities = filtered_entities;
let query_embedding = ctx.ensure_embedding()?.clone();
let mut chunk_results: Vec<Scored<TextChunk>> =
ctx.chunk_candidates.values().cloned().collect();
sort_by_fused_desc(&mut chunk_results);
@@ -507,7 +522,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
ctx.db_client,
&ctx.user_id,
weights,
&query_embedding,
)
.await?;
@@ -579,13 +593,23 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
debug!("Collecting vector chunk candidates for revised strategy");
let embedding = ctx.ensure_embedding()?.clone();
let tuning = &ctx.config.tuning;
let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
let weights = FusionWeights::default();
let mut vector_chunks: Vec<Scored<TextChunk>> = TextChunk::vector_search(
tuning.chunk_vector_take,
embedding,
ctx.db_client,
&ctx.user_id,
)
.await?;
.await?
.into_iter()
.map(|row| {
let mut scored = Scored::new(row.chunk).with_vector_score(row.score);
let fused = fuse_scores(&scored.scores, weights);
scored.update_fused(fused);
scored
})
.collect();
if ctx.diagnostics_enabled() {
ctx.record_collect_candidates(CollectCandidatesStats {
@@ -617,7 +641,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
return Ok(());
};
let documents = build_snippet_rerank_documents(
let documents = build_chunk_rerank_documents(
&ctx.revised_chunk_values,
ctx.config.tuning.rerank_keep_top.max(1),
);
@@ -628,11 +652,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
match reranker.rerank(&ctx.input_text, documents).await {
Ok(results) if !results.is_empty() => {
apply_snippet_rerank_results(
&mut ctx.revised_chunk_values,
&ctx.config.tuning,
results,
);
apply_chunk_rerank_results(&mut ctx.revised_chunk_values, &ctx.config.tuning, results);
}
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
Err(err) => warn!(
@@ -649,7 +669,7 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling chunk-only retrieval results");
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
let question_terms = extract_keywords(&ctx.input_text);
rank_snippet_chunks_by_combined_score(
rank_chunks_by_combined_score(
&mut chunk_values,
&question_terms,
ctx.config.tuning.lexical_match_weight,
@@ -662,12 +682,9 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
ctx.chunk_results = chunk_values
.into_iter()
.map(|chunk| {
let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
RetrievedChunk {
chunk: text_chunk,
score: chunk.fused,
}
.map(|chunk| RetrievedChunk {
chunk: chunk.item,
score: chunk.fused,
})
.collect();
@@ -691,7 +708,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
debug!("Assembling final retrieved entities");
let tuning = &ctx.config.tuning;
let query_embedding = ctx.ensure_embedding()?.clone();
let question_terms = extract_keywords(&ctx.input_text);
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
@@ -704,9 +720,8 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
for chunk_list in chunk_by_source.values_mut() {
chunk_list.sort_by(|a, b| {
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
// No base-table embeddings; order by fused score only.
b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)
});
}
@@ -930,7 +945,6 @@ async fn enrich_chunks_from_entities(
db_client: &SurrealDbClient,
user_id: &str,
weights: FusionWeights,
query_embedding: &[f32],
) -> Result<(), AppError> {
let mut source_ids: HashSet<String> = HashSet::new();
for entity in entities {
@@ -964,16 +978,7 @@ async fn enrich_chunks_from_entities(
.copied()
.unwrap_or(0.0);
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
entry.scores.vector = Some(
entry
.scores
.vector
.unwrap_or(0.0)
.max(entity_score * 0.8)
.max(similarity),
);
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
let fused = fuse_scores(&entry.scores, weights);
entry.update_fused(fused);
entry.item = chunk;
@@ -982,24 +987,6 @@ async fn enrich_chunks_from_entities(
Ok(())
}
/// Computes the cosine similarity between a query vector and an embedding.
///
/// Returns `0.0` when either slice is empty, when the two slices differ in
/// length, or when either vector has a zero squared norm. Otherwise returns
/// the dot product divided by the product of the two Euclidean norms.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    let comparable =
        !query.is_empty() && !embedding.is_empty() && query.len() == embedding.len();
    if !comparable {
        return 0.0;
    }

    // One pass over the paired components, accumulating the dot product and
    // both squared norms in element order.
    let (dot, query_sq, embedding_sq) = query.iter().zip(embedding.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, query_sq, embedding_sq), (q, e)| {
            (dot + q * e, query_sq + q * q, embedding_sq + e * e)
        },
    );

    if query_sq == 0.0 || embedding_sq == 0.0 {
        0.0
    } else {
        dot / (query_sq.sqrt() * embedding_sq.sqrt())
    }
}
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
if ctx.filtered_entities.is_empty() {
return Vec::new();
@@ -1050,10 +1037,7 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
.collect()
}
fn build_snippet_rerank_documents(
chunks: &[Scored<ChunkSnippet>],
max_chunks: usize,
) -> Vec<String> {
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
chunks
.iter()
.take(max_chunks)
@@ -1124,8 +1108,8 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
}
}
fn apply_snippet_rerank_results(
chunks: &mut Vec<Scored<ChunkSnippet>>,
fn apply_chunk_rerank_results(
chunks: &mut Vec<Scored<TextChunk>>,
tuning: &RetrievalTuning,
results: Vec<RerankResult>,
) {
@@ -1133,7 +1117,7 @@ fn apply_snippet_rerank_results(
return;
}
let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
let mut remaining: Vec<Option<Scored<TextChunk>>> =
std::mem::take(chunks).into_iter().map(Some).collect();
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
@@ -1146,7 +1130,7 @@ fn apply_snippet_rerank_results(
clamp_unit(tuning.rerank_blend_weight)
};
let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
if let Some(slot) = remaining.get_mut(result.index) {
if let Some(mut candidate) = slot.take() {
@@ -1217,32 +1201,6 @@ fn extract_keywords(text: &str) -> Vec<String> {
terms
}
/// Optionally boosts each snippet candidate's fused score with a lexical
/// overlap term, then sorts the candidates by fused score in descending order.
///
/// The boost is applied only when `lexical_weight` is positive and
/// `question_terms` is non-empty; each boosted score is passed through
/// `clamp_unit` before being stored. The sort always runs, and pairs whose
/// scores cannot be compared are treated as equal.
fn rank_snippet_chunks_by_combined_score(
    candidates: &mut [Scored<ChunkSnippet>],
    question_terms: &[String],
    lexical_weight: f32,
) {
    let apply_lexical_boost = lexical_weight > 0.0 && !question_terms.is_empty();
    if apply_lexical_boost {
        candidates.iter_mut().for_each(|scored| {
            let overlap = lexical_overlap_score(question_terms, &scored.item.chunk);
            let boosted = clamp_unit(scored.fused + lexical_weight * overlap);
            scored.update_fused(boosted);
        });
    }

    // Highest fused score first.
    candidates.sort_by(|lhs, rhs| {
        rhs.fused
            .partial_cmp(&lhs.fused)
            .unwrap_or(Ordering::Equal)
    });
}
/// Converts a `ChunkSnippet` into a `TextChunk` attributed to `user_id`.
///
/// The chunk is built via `TextChunk::new` from the snippet's source id and
/// text, with an empty vector as the third argument (presumably the embedding
/// slot — confirm against `TextChunk::new`). The id assigned by the
/// constructor is then overwritten with the snippet's own id.
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
    let source_id = snippet.source_id.clone();
    let text = snippet.chunk;
    let owner = user_id.to_owned();

    let mut text_chunk = TextChunk::new(source_id, text, Vec::new(), owner);
    // Keep the snippet's identity rather than whatever `new` assigned.
    text_chunk.id = snippet.id;
    text_chunk
}
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
if terms.is_empty() {
return 0.0;