mirror of
https://github.com/perstarkse/minne.git
synced 2026-04-17 14:39:45 +02:00
evals: v3, embeddings at the side
additional indexes
This commit is contained in:
@@ -123,10 +123,6 @@ mod tests {
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Builds the all-zero placeholder embedding used by the test fixtures.
///
/// The vector always has 1536 components, each `0.0`.
fn dummy_embedding() -> Vec<f32> {
    const DIMENSIONS: usize = 1536;
    std::iter::repeat(0.0_f32).take(DIMENSIONS).collect()
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_single_field_score_for_name() {
|
||||
let namespace = "fts_test_ns";
|
||||
@@ -146,7 +142,6 @@ mod tests {
|
||||
"completely unrelated description".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
@@ -194,7 +189,6 @@ mod tests {
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
@@ -239,11 +233,10 @@ mod tests {
|
||||
let chunk = TextChunk::new(
|
||||
"source_chunk".into(),
|
||||
"GraphQL documentation reference".into(),
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk.clone())
|
||||
TextChunk::store_with_embedding(chunk.clone(), vec![0.0; 1536], &db)
|
||||
.await
|
||||
.expect("failed to insert chunk");
|
||||
|
||||
|
||||
@@ -171,7 +171,6 @@ mod tests {
|
||||
let source_id3 = "source789".to_string();
|
||||
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Entity with source_id1
|
||||
@@ -181,7 +180,6 @@ mod tests {
|
||||
"Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -192,7 +190,6 @@ mod tests {
|
||||
"Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -203,7 +200,6 @@ mod tests {
|
||||
"Description 3".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -214,7 +210,6 @@ mod tests {
|
||||
"Description 4".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -318,7 +313,6 @@ mod tests {
|
||||
|
||||
// Create some test entities
|
||||
let entity_type = KnowledgeEntityType::Document;
|
||||
let embedding = vec![0.1, 0.2, 0.3];
|
||||
let user_id = "user123".to_string();
|
||||
|
||||
// Create the central entity we'll query relationships for
|
||||
@@ -328,7 +322,6 @@ mod tests {
|
||||
"Central Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -339,7 +332,6 @@ mod tests {
|
||||
"Related Description 1".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -349,7 +341,6 @@ mod tests {
|
||||
"Related Description 2".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
@@ -360,7 +351,6 @@ mod tests {
|
||||
"Unrelated Description".to_string(),
|
||||
entity_type.clone(),
|
||||
None,
|
||||
embedding.clone(),
|
||||
user_id.clone(),
|
||||
);
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ pub mod graph;
|
||||
pub mod pipeline;
|
||||
pub mod reranking;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
@@ -57,6 +56,7 @@ pub async fn retrieve_entities(
|
||||
pipeline::run_pipeline(
|
||||
db_client,
|
||||
openai_client,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
@@ -110,10 +110,10 @@ mod tests {
|
||||
|
||||
db.query(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
|
||||
DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
|
||||
DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding;
|
||||
DEFINE INDEX idx_embedding_text_chunk_embedding ON TABLE text_chunk_embedding FIELDS embedding HNSW DIMENSION 3;
|
||||
REMOVE INDEX IF EXISTS idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding;
|
||||
DEFINE INDEX idx_embedding_knowledge_entity_embedding ON TABLE knowledge_entity_embedding FIELDS embedding HNSW DIMENSION 3;
|
||||
COMMIT TRANSACTION;",
|
||||
)
|
||||
.await
|
||||
@@ -132,20 +132,18 @@ mod tests {
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk = TextChunk::new(
|
||||
entity.source_id.clone(),
|
||||
"Tokio uses cooperative scheduling for fairness.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
KnowledgeEntity::store_with_embedding(entity.clone(), entity_embedding_high(), &db)
|
||||
.await
|
||||
.expect("Failed to store entity");
|
||||
db.store_item(chunk.clone())
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk");
|
||||
|
||||
@@ -153,6 +151,7 @@ mod tests {
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
@@ -193,7 +192,6 @@ mod tests {
|
||||
"Explores async runtimes and scheduling strategies.".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_high(),
|
||||
user_id.into(),
|
||||
);
|
||||
let neighbor = KnowledgeEntity::new(
|
||||
@@ -202,34 +200,31 @@ mod tests {
|
||||
"Details on Tokio's cooperative scheduler.".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
entity_embedding_low(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(primary.clone())
|
||||
KnowledgeEntity::store_with_embedding(primary.clone(), entity_embedding_high(), &db)
|
||||
.await
|
||||
.expect("Failed to store primary entity");
|
||||
db.store_item(neighbor.clone())
|
||||
KnowledgeEntity::store_with_embedding(neighbor.clone(), entity_embedding_low(), &db)
|
||||
.await
|
||||
.expect("Failed to store neighbor entity");
|
||||
|
||||
let primary_chunk = TextChunk::new(
|
||||
primary.source_id.clone(),
|
||||
"Rust async tasks use Tokio's cooperative scheduler.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
let neighbor_chunk = TextChunk::new(
|
||||
neighbor.source_id.clone(),
|
||||
"Tokio's scheduler manages task fairness across executors.".into(),
|
||||
chunk_embedding_secondary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(primary_chunk)
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store primary chunk");
|
||||
db.store_item(neighbor_chunk)
|
||||
TextChunk::store_with_embedding(neighbor_chunk, chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store neighbor chunk");
|
||||
|
||||
@@ -249,6 +244,7 @@ mod tests {
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"Rust concurrency async tasks",
|
||||
user_id,
|
||||
@@ -270,6 +266,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
println!("{:?}", entities);
|
||||
|
||||
let neighbor_entry =
|
||||
neighbor_entry.expect("Graph-enriched neighbor should appear in results");
|
||||
|
||||
@@ -293,20 +291,18 @@ mod tests {
|
||||
let chunk_one = TextChunk::new(
|
||||
"src_alpha".into(),
|
||||
"Tokio tasks execute on worker threads managed by the runtime.".into(),
|
||||
chunk_embedding_primary(),
|
||||
user_id.into(),
|
||||
);
|
||||
let chunk_two = TextChunk::new(
|
||||
"src_beta".into(),
|
||||
"Hyper utilizes Tokio to drive HTTP state machines efficiently.".into(),
|
||||
chunk_embedding_secondary(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk_one.clone())
|
||||
TextChunk::store_with_embedding(chunk_one.clone(), chunk_embedding_primary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk one");
|
||||
db.store_item(chunk_two.clone())
|
||||
TextChunk::store_with_embedding(chunk_two.clone(), chunk_embedding_secondary(), &db)
|
||||
.await
|
||||
.expect("Failed to store chunk two");
|
||||
|
||||
@@ -315,6 +311,7 @@ mod tests {
|
||||
let results = pipeline::run_pipeline_with_embedding(
|
||||
&db,
|
||||
&openai_client,
|
||||
None,
|
||||
test_embedding(),
|
||||
"tokio runtime worker behavior",
|
||||
user_id,
|
||||
|
||||
@@ -112,6 +112,7 @@ pub struct PipelineRunOutput<T> {
|
||||
pub async fn run_pipeline(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
@@ -137,6 +138,7 @@ pub async fn run_pipeline(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -153,6 +155,7 @@ pub async fn run_pipeline(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -169,6 +172,7 @@ pub async fn run_pipeline(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -185,6 +189,7 @@ pub async fn run_pipeline(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
None,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -201,6 +206,7 @@ pub async fn run_pipeline(
|
||||
pub async fn run_pipeline_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
@@ -214,6 +220,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -230,6 +237,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -246,6 +254,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -262,6 +271,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -283,6 +293,7 @@ pub async fn run_pipeline_with_embedding(
|
||||
pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
@@ -296,6 +307,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -316,6 +328,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -340,6 +353,7 @@ pub async fn run_pipeline_with_embedding_with_metrics(
|
||||
pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
@@ -353,6 +367,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -373,6 +388,7 @@ pub async fn run_pipeline_with_embedding_with_diagnostics(
|
||||
driver,
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
Some(query_embedding),
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -419,6 +435,7 @@ async fn execute_strategy<D: StrategyDriver>(
|
||||
driver: D,
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
query_embedding: Option<Vec<f32>>,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
@@ -430,6 +447,7 @@ async fn execute_strategy<D: StrategyDriver>(
|
||||
Some(embedding) => PipelineContext::with_embedding(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
embedding,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
@@ -439,6 +457,7 @@ async fn execute_strategy<D: StrategyDriver>(
|
||||
None => PipelineContext::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text.to_owned(),
|
||||
user_id.to_owned(),
|
||||
config,
|
||||
|
||||
@@ -6,7 +6,7 @@ use common::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
utils::{embedding::generate_embedding, embedding::EmbeddingProvider},
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
@@ -24,10 +24,6 @@ use crate::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
},
|
||||
vector::{
|
||||
find_chunk_snippets_by_vector_similarity_with_embedding,
|
||||
find_items_by_vector_similarity_with_embedding, ChunkSnippet,
|
||||
},
|
||||
RetrievedChunk, RetrievedEntity,
|
||||
};
|
||||
|
||||
@@ -43,6 +39,7 @@ use super::{
|
||||
pub struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
@@ -51,7 +48,7 @@ pub struct PipelineContext<'a> {
|
||||
pub chunk_candidates: HashMap<String, Scored<TextChunk>>,
|
||||
pub filtered_entities: Vec<Scored<KnowledgeEntity>>,
|
||||
pub chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub revised_chunk_values: Vec<Scored<ChunkSnippet>>,
|
||||
pub revised_chunk_values: Vec<Scored<TextChunk>>,
|
||||
pub reranker: Option<RerankerLease>,
|
||||
pub diagnostics: Option<PipelineDiagnostics>,
|
||||
pub entity_results: Vec<RetrievedEntity>,
|
||||
@@ -63,6 +60,7 @@ impl<'a> PipelineContext<'a> {
|
||||
pub fn new(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
config: RetrievalConfig,
|
||||
@@ -71,6 +69,7 @@ impl<'a> PipelineContext<'a> {
|
||||
Self {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
@@ -91,6 +90,7 @@ impl<'a> PipelineContext<'a> {
|
||||
pub fn with_embedding(
|
||||
db_client: &'a SurrealDbClient,
|
||||
openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
query_embedding: Vec<f32>,
|
||||
input_text: String,
|
||||
user_id: String,
|
||||
@@ -100,6 +100,7 @@ impl<'a> PipelineContext<'a> {
|
||||
let mut ctx = Self::new(
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
config,
|
||||
@@ -299,8 +300,16 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding =
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?;
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await.map_err(|e| {
|
||||
AppError::InternalError(format!(
|
||||
"Failed to generate embedding with provider: {}",
|
||||
e
|
||||
))
|
||||
})?
|
||||
} else {
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||
};
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
@@ -315,19 +324,17 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
|
||||
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
let (vector_entity_results, vector_chunk_results, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
KnowledgeEntity::vector_search(
|
||||
tuning.entity_vector_take,
|
||||
embedding.clone(),
|
||||
ctx.db_client,
|
||||
"knowledge_entity",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
TextChunk::vector_search(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
"text_chunk",
|
||||
&ctx.user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
@@ -346,6 +353,15 @@ pub async fn collect_candidates(ctx: &mut PipelineContext<'_>) -> Result<(), App
|
||||
),
|
||||
)?;
|
||||
|
||||
let vector_entities: Vec<Scored<KnowledgeEntity>> = vector_entity_results
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.entity).with_vector_score(row.score))
|
||||
.collect();
|
||||
let vector_chunks: Vec<Scored<TextChunk>> = vector_chunk_results
|
||||
.into_iter()
|
||||
.map(|row| Scored::new(row.chunk).with_vector_score(row.score))
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
vector_entities = vector_entities.len(),
|
||||
vector_chunks = vector_chunks.len(),
|
||||
@@ -419,14 +435,15 @@ pub async fn expand_graph(ctx: &mut PipelineContext<'_>) -> Result<(), AppError>
|
||||
}
|
||||
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
let neighbor_id = neighbor.id.clone();
|
||||
if neighbor_id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * tuning.graph_score_decay);
|
||||
let entry = ctx
|
||||
.entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.entry(neighbor_id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
@@ -490,8 +507,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
|
||||
ctx.filtered_entities = filtered_entities;
|
||||
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> =
|
||||
ctx.chunk_candidates.values().cloned().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
@@ -507,7 +522,6 @@ pub async fn attach_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
weights,
|
||||
&query_embedding,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -579,13 +593,23 @@ pub async fn collect_vector_chunks(ctx: &mut PipelineContext<'_>) -> Result<(),
|
||||
debug!("Collecting vector chunk candidates for revised strategy");
|
||||
let embedding = ctx.ensure_embedding()?.clone();
|
||||
let tuning = &ctx.config.tuning;
|
||||
let mut vector_chunks = find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let mut vector_chunks: Vec<Scored<TextChunk>> = TextChunk::vector_search(
|
||||
tuning.chunk_vector_take,
|
||||
embedding,
|
||||
ctx.db_client,
|
||||
&ctx.user_id,
|
||||
)
|
||||
.await?;
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let mut scored = Scored::new(row.chunk).with_vector_score(row.score);
|
||||
let fused = fuse_scores(&scored.scores, weights);
|
||||
scored.update_fused(fused);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
if ctx.diagnostics_enabled() {
|
||||
ctx.record_collect_candidates(CollectCandidatesStats {
|
||||
@@ -617,7 +641,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let documents = build_snippet_rerank_documents(
|
||||
let documents = build_chunk_rerank_documents(
|
||||
&ctx.revised_chunk_values,
|
||||
ctx.config.tuning.rerank_keep_top.max(1),
|
||||
);
|
||||
@@ -628,11 +652,7 @@ pub async fn rerank_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError
|
||||
|
||||
match reranker.rerank(&ctx.input_text, documents).await {
|
||||
Ok(results) if !results.is_empty() => {
|
||||
apply_snippet_rerank_results(
|
||||
&mut ctx.revised_chunk_values,
|
||||
&ctx.config.tuning,
|
||||
results,
|
||||
);
|
||||
apply_chunk_rerank_results(&mut ctx.revised_chunk_values, &ctx.config.tuning, results);
|
||||
}
|
||||
Ok(_) => debug!("Chunk reranker returned no results; retaining original order"),
|
||||
Err(err) => warn!(
|
||||
@@ -649,7 +669,7 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling chunk-only retrieval results");
|
||||
let mut chunk_values = std::mem::take(&mut ctx.revised_chunk_values);
|
||||
let question_terms = extract_keywords(&ctx.input_text);
|
||||
rank_snippet_chunks_by_combined_score(
|
||||
rank_chunks_by_combined_score(
|
||||
&mut chunk_values,
|
||||
&question_terms,
|
||||
ctx.config.tuning.lexical_match_weight,
|
||||
@@ -662,12 +682,9 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
|
||||
ctx.chunk_results = chunk_values
|
||||
.into_iter()
|
||||
.map(|chunk| {
|
||||
let text_chunk = snippet_into_text_chunk(chunk.item, &ctx.user_id);
|
||||
RetrievedChunk {
|
||||
chunk: text_chunk,
|
||||
score: chunk.fused,
|
||||
}
|
||||
.map(|chunk| RetrievedChunk {
|
||||
chunk: chunk.item,
|
||||
score: chunk.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -691,7 +708,6 @@ pub fn assemble_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Assembling final retrieved entities");
|
||||
let tuning = &ctx.config.tuning;
|
||||
let query_embedding = ctx.ensure_embedding()?.clone();
|
||||
let question_terms = extract_keywords(&ctx.input_text);
|
||||
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
@@ -704,9 +720,8 @@ pub fn assemble(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
chunk_list.sort_by(|a, b| {
|
||||
let sim_a = cosine_similarity(&query_embedding, &a.item.embedding);
|
||||
let sim_b = cosine_similarity(&query_embedding, &b.item.embedding);
|
||||
sim_b.partial_cmp(&sim_a).unwrap_or(Ordering::Equal)
|
||||
// No base-table embeddings; order by fused score only.
|
||||
b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal)
|
||||
});
|
||||
}
|
||||
|
||||
@@ -930,7 +945,6 @@ async fn enrich_chunks_from_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
query_embedding: &[f32],
|
||||
) -> Result<(), AppError> {
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
@@ -964,16 +978,7 @@ async fn enrich_chunks_from_entities(
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let similarity = cosine_similarity(query_embedding, &chunk.embedding);
|
||||
|
||||
entry.scores.vector = Some(
|
||||
entry
|
||||
.scores
|
||||
.vector
|
||||
.unwrap_or(0.0)
|
||||
.max(entity_score * 0.8)
|
||||
.max(similarity),
|
||||
);
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
@@ -982,24 +987,6 @@ async fn enrich_chunks_from_entities(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Computes the cosine similarity between `query` and `embedding`.
///
/// Returns `0.0` when either slice is empty, when the slices differ in
/// length, or when either vector has a zero squared norm; otherwise returns
/// the dot product divided by the product of the two Euclidean norms.
fn cosine_similarity(query: &[f32], embedding: &[f32]) -> f32 {
    // Equal lengths plus a non-empty query imply a non-empty embedding as well.
    let comparable = !query.is_empty() && query.len() == embedding.len();
    if !comparable {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass,
    // in the same element order as a plain loop would.
    let (dot, query_sq, embedding_sq) = query.iter().zip(embedding).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, query_sq, embedding_sq), (q, e)| {
            (dot + q * e, query_sq + q * q, embedding_sq + e * e)
        },
    );

    if query_sq == 0.0 || embedding_sq == 0.0 {
        return 0.0;
    }

    dot / (query_sq.sqrt() * embedding_sq.sqrt())
}
|
||||
|
||||
fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usize) -> Vec<String> {
|
||||
if ctx.filtered_entities.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -1050,10 +1037,7 @@ fn build_rerank_documents(ctx: &PipelineContext<'_>, max_chunks_per_entity: usiz
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn build_snippet_rerank_documents(
|
||||
chunks: &[Scored<ChunkSnippet>],
|
||||
max_chunks: usize,
|
||||
) -> Vec<String> {
|
||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
||||
chunks
|
||||
.iter()
|
||||
.take(max_chunks)
|
||||
@@ -1124,8 +1108,8 @@ fn apply_rerank_results(ctx: &mut PipelineContext<'_>, results: Vec<RerankResult
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_snippet_rerank_results(
|
||||
chunks: &mut Vec<Scored<ChunkSnippet>>,
|
||||
fn apply_chunk_rerank_results(
|
||||
chunks: &mut Vec<Scored<TextChunk>>,
|
||||
tuning: &RetrievalTuning,
|
||||
results: Vec<RerankResult>,
|
||||
) {
|
||||
@@ -1133,7 +1117,7 @@ fn apply_snippet_rerank_results(
|
||||
return;
|
||||
}
|
||||
|
||||
let mut remaining: Vec<Option<Scored<ChunkSnippet>>> =
|
||||
let mut remaining: Vec<Option<Scored<TextChunk>>> =
|
||||
std::mem::take(chunks).into_iter().map(Some).collect();
|
||||
|
||||
let raw_scores: Vec<f32> = results.iter().map(|r| r.score).collect();
|
||||
@@ -1146,7 +1130,7 @@ fn apply_snippet_rerank_results(
|
||||
clamp_unit(tuning.rerank_blend_weight)
|
||||
};
|
||||
|
||||
let mut reranked: Vec<Scored<ChunkSnippet>> = Vec::with_capacity(remaining.len());
|
||||
let mut reranked: Vec<Scored<TextChunk>> = Vec::with_capacity(remaining.len());
|
||||
for (result, normalized) in results.into_iter().zip(normalized_scores.into_iter()) {
|
||||
if let Some(slot) = remaining.get_mut(result.index) {
|
||||
if let Some(mut candidate) = slot.take() {
|
||||
@@ -1217,32 +1201,6 @@ fn extract_keywords(text: &str) -> Vec<String> {
|
||||
terms
|
||||
}
|
||||
|
||||
fn rank_snippet_chunks_by_combined_score(
|
||||
candidates: &mut [Scored<ChunkSnippet>],
|
||||
question_terms: &[String],
|
||||
lexical_weight: f32,
|
||||
) {
|
||||
if lexical_weight > 0.0 && !question_terms.is_empty() {
|
||||
for candidate in candidates.iter_mut() {
|
||||
let lexical = lexical_overlap_score(question_terms, &candidate.item.chunk);
|
||||
let combined = clamp_unit(candidate.fused + lexical_weight * lexical);
|
||||
candidate.update_fused(combined);
|
||||
}
|
||||
}
|
||||
candidates.sort_by(|a, b| b.fused.partial_cmp(&a.fused).unwrap_or(Ordering::Equal));
|
||||
}
|
||||
|
||||
fn snippet_into_text_chunk(snippet: ChunkSnippet, user_id: &str) -> TextChunk {
|
||||
let mut chunk = TextChunk::new(
|
||||
snippet.source_id.clone(),
|
||||
snippet.chunk,
|
||||
Vec::new(),
|
||||
user_id.to_owned(),
|
||||
);
|
||||
chunk.id = snippet.id;
|
||||
chunk
|
||||
}
|
||||
|
||||
fn lexical_overlap_score(terms: &[String], haystack: &str) -> f32 {
|
||||
if terms.is_empty() {
|
||||
return 0.0;
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{file_info::deserialize_flexible_id, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
/// This function generates embeddings for the input text, constructs a query to find the closest matches in the database,
|
||||
/// and then deserializes the results into the specified type `T`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `take` - The number of items to retrieve from the database.
|
||||
/// * `input_text` - The text to generate embeddings for.
|
||||
/// * `db_client` - The SurrealDB client to use for querying the database.
|
||||
/// * `table` - The table to query in the database.
|
||||
/// * `openai_client` - The OpenAI client to use for generating embeddings.
|
||||
/// * 'user_id`- The user id of the current user.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of type `T` containing the closest matches to the input text. Returns a `ProcessingError` if an error occurs.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: usize,
|
||||
input_text: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
|
||||
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Row shape returned by the KNN id/distance query issued in
/// `find_items_by_vector_similarity_with_embedding`.
#[derive(Debug, Deserialize)]
|
||||
struct DistanceRow {
|
||||
/// Record id, deserialized through `deserialize_flexible_id`.
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
/// Value selected as `vector::distance::knn() AS distance`; may be absent.
distance: Option<f32>,
|
||||
}
|
||||
|
||||
pub async fn find_items_by_vector_similarity_with_embedding<T>(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, vector::distance::knn() AS distance \
|
||||
FROM {table} \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let distance_rows: Vec<DistanceRow> = response.take(0)?;
|
||||
|
||||
if distance_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut min_distance = f32::MAX;
|
||||
let mut max_distance = f32::MIN;
|
||||
|
||||
for row in &distance_rows {
|
||||
if let Some(distance) = row.distance {
|
||||
if distance.is_finite() {
|
||||
if distance < min_distance {
|
||||
min_distance = distance;
|
||||
}
|
||||
if distance > max_distance {
|
||||
max_distance = distance;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let normalize = min_distance.is_finite()
|
||||
&& max_distance.is_finite()
|
||||
&& (max_distance - min_distance).abs() > f32::EPSILON;
|
||||
|
||||
let mut scored = Vec::with_capacity(distance_rows.len());
|
||||
for row in distance_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let similarity = row
|
||||
.distance
|
||||
.map(|distance| {
|
||||
if normalize {
|
||||
let span = max_distance - min_distance;
|
||||
if span.abs() < f32::EPSILON {
|
||||
1.0
|
||||
} else {
|
||||
let normalized = 1.0 - ((distance - min_distance) / span);
|
||||
clamp_unit(normalized)
|
||||
}
|
||||
} else {
|
||||
distance_to_similarity(distance)
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
scored.push(Scored::new(item).with_vector_score(similarity));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
/// Lightweight view of a text chunk: the `id`, `source_id` and `chunk`
/// columns selected by the chunk vector search, without an embedding.
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChunkSnippet {
|
||||
/// Id of the chunk record.
pub id: String,
|
||||
/// Value of the chunk's `source_id` column.
pub source_id: String,
|
||||
/// Value of the chunk's `chunk` column (presumably the chunk text — confirm against `TextChunk`).
pub chunk: String,
|
||||
}
|
||||
|
||||
/// Row shape returned by the chunk KNN query in
/// `find_chunk_snippets_by_vector_similarity_with_embedding`.
#[derive(Debug, Deserialize)]
|
||||
struct ChunkDistanceRow {
|
||||
/// Value selected as `vector::distance::knn() AS distance`; may be absent.
distance: Option<f32>,
|
||||
/// Record id, deserialized through `deserialize_flexible_id`.
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
pub id: String,
|
||||
/// Value of the `source_id` column.
pub source_id: String,
|
||||
/// Value of the `chunk` column.
pub chunk: String,
|
||||
}
|
||||
|
||||
pub async fn find_chunk_snippets_by_vector_similarity_with_embedding(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<ChunkSnippet>>, AppError> {
|
||||
let embedding_literal = serde_json::to_string(&query_embedding)
|
||||
.map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
|
||||
|
||||
let closest_query = format!(
|
||||
"SELECT id, source_id, chunk, vector::distance::knn() AS distance \
|
||||
FROM text_chunk \
|
||||
WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
|
||||
LIMIT $limit",
|
||||
take = take,
|
||||
embedding = embedding_literal
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(closest_query)
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let rows: Vec<ChunkDistanceRow> = response.take(0)?;
|
||||
|
||||
let mut scored = Vec::with_capacity(rows.len());
|
||||
for row in rows {
|
||||
let similarity = row.distance.map(distance_to_similarity).unwrap_or_default();
|
||||
scored.push(
|
||||
Scored::new(ChunkSnippet {
|
||||
id: row.id,
|
||||
source_id: row.source_id,
|
||||
chunk: row.chunk,
|
||||
})
|
||||
.with_vector_score(similarity),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
Reference in New Issue
Block a user