diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 732067f..ed0694e 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -258,7 +258,7 @@ impl KnowledgeEntity { /// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores. pub async fn vector_search( take: usize, - query_embedding: Vec, + query_embedding: &[f32], db: &SurrealDbClient, user_id: &str, ) -> Result, AppError> { @@ -286,7 +286,7 @@ impl KnowledgeEntity { let mut response = db .query(&sql) - .bind(("embedding", query_embedding)) + .bind(("embedding", query_embedding.to_vec())) .bind(("user_id", user_id.to_string())) .await .map_err(AppError::from)?; @@ -408,7 +408,7 @@ impl KnowledgeEntity { ) }) .collect(); - let embeddings = provider.embed_batch(inputs).await?; + let embeddings = provider.embed_batch(&inputs).await?; if embeddings.len() != batch.len() { return Err(AppError::internal(format!( "embedding batch returned {} vectors for {} entities", @@ -817,7 +817,7 @@ mod tests { .await .expect("Failed to redefine index length"); - let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") + let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user") .await .expect("vector search"); assert!(results.is_empty()); @@ -878,7 +878,7 @@ mod tests { .with_context(|| "fetch embedding".to_string())?; assert!(fetched_emb.is_some()); - let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id) .await .with_context(|| "vector search".to_string())?; @@ -965,7 +965,7 @@ mod tests { .with_context(|| "get embedding e2".to_string())? .is_some()); - let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) + let results = KnowledgeEntity::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id) .await .with_context(|| "vector search".to_string())?; @@ -1030,7 +1030,7 @@ mod tests { .await .with_context(|| "delete entity".to_string())?; - let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id) .await .with_context(|| "search should succeed even with orphans".to_string())?; diff --git a/common/src/storage/types/knowledge_relationship.rs b/common/src/storage/types/knowledge_relationship.rs index 64e383a..6c3b4fa 100644 --- a/common/src/storage/types/knowledge_relationship.rs +++ b/common/src/storage/types/knowledge_relationship.rs @@ -42,12 +42,24 @@ impl KnowledgeRelationship { } } - pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> { + pub async fn store_relationship(self, db_client: &SurrealDbClient) -> Result<(), AppError> { User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client) .await?; User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client) .await?; + let Self { + id, + in_, + out, + metadata: + RelationshipMetadata { + user_id, + source_id, + relationship_type, + }, + } = self; + db_client .client .query( @@ -62,12 +74,12 @@ impl KnowledgeRelationship { metadata.relationship_type = $relationship_type; COMMIT TRANSACTION;"#, ) - .bind(("rel_id", self.id.clone())) - .bind(("in_id", self.in_.clone())) - .bind(("out_id", self.out.clone())) - .bind(("user_id", self.metadata.user_id.clone())) - .bind(("source_id", self.metadata.source_id.clone())) - .bind(("relationship_type", self.metadata.relationship_type.clone())) + .bind(("rel_id", id)) + .bind(("in_id", in_)) + .bind(("out_id", out)) + .bind(("user_id", user_id)) + .bind(("source_id", source_id)) + .bind(("relationship_type", relationship_type)) .await .map_err(AppError::from)? .check() @@ -230,13 +242,14 @@ mod tests { source_id.clone(), relationship_type, ); + let relationship_id = relationship.id.clone(); relationship .store_relationship(&db) .await .with_context(|| "Failed to store relationship".to_string())?; - let persisted = get_relationship_by_id(&relationship.id, &db) + let persisted = get_relationship_by_id(&relationship_id, &db) .await .expect("Relationship should be retrievable by id"); assert_eq!(persisted.in_, entity1_id); @@ -296,6 +309,7 @@ mod tests { "source123'; DELETE FROM relates_to; --".to_string(), "references'; UPDATE user SET admin = true; --".to_string(), ); + let relationship_id = relationship.id.clone(); relationship .store_relationship(&db) @@ -305,7 +319,7 @@ mod tests { let mut res = db .client .query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)") - .bind(("id", relationship.id.clone())) + .bind(("id", relationship_id)) .await .expect("query relationship by id failed"); let rows: Vec = res.take(0).expect("take rows"); @@ -338,6 +352,7 @@ mod tests { source_id.clone(), relationship_type, ); + let relationship_id = relationship.id.clone(); relationship .store_relationship(&db) @@ -357,7 +372,7 @@ mod tests { "Relationship should exist before deletion" ); - KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db) + KnowledgeRelationship::delete_relationship_by_id(&relationship_id, user_id, &db) .await .with_context(|| "Failed to delete relationship by ID".to_string())?; @@ -391,6 +406,7 @@ mod tests { source_id, "references".to_string(), ); + let relationship_id = relationship.id.clone(); relationship .store_relationship(&db) @@ -409,7 +425,7 @@ mod tests { ); let result = KnowledgeRelationship::delete_relationship_by_id( - &relationship.id, + &relationship_id, "different-user", &db, ) @@ -472,6 +488,9 @@ mod tests { different_source_id.clone(), "mentions".to_string(), ); + let relationship1_id = relationship1.id.clone(); + let relationship2_id = relationship2.id.clone(); + let different_relationship_id = different_relationship.id.clone(); relationship1 .store_relationship(&db) @@ -508,9 +527,9 @@ mod tests { .await .with_context(|| "Failed to delete relationships by source_id".to_string())?; - let result1 = get_relationship_by_id(&relationship1.id, &db).await; - let result2 = get_relationship_by_id(&relationship2.id, &db).await; - let different_result = get_relationship_by_id(&different_relationship.id, &db).await; + let result1 = get_relationship_by_id(&relationship1_id, &db).await; + let result2 = get_relationship_by_id(&relationship2_id, &db).await; + let different_result = get_relationship_by_id(&different_relationship_id, &db).await; assert!(result1.is_none(), "Relationship 1 should be deleted"); assert!(result2.is_none(), "Relationship 2 should be deleted"); @@ -548,6 +567,8 @@ mod tests { shared_source.to_string(), "references".to_string(), ); + let rel_a_id = rel_a.id.clone(); + let rel_b_id = rel_b.id.clone(); rel_a.store_relationship(&db).await?; rel_b.store_relationship(&db).await?; @@ -555,8 +576,8 @@ mod tests { KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db) .await?; - assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none()); - assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some()); + assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none()); + assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some()); Ok(()) } @@ -586,6 +607,8 @@ mod tests { "other_source".to_string(), "contains".to_string(), ); + let safe_relationship_id = safe_relationship.id.clone(); + let other_relationship_id = other_relationship.id.clone(); safe_relationship .store_relationship(&db) @@ -604,8 +627,8 @@ mod tests { .await .expect("delete call should succeed"); - let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await; - let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await; + let remaining_safe = get_relationship_by_id(&safe_relationship_id, &db).await; + let remaining_other = get_relationship_by_id(&other_relationship_id, &db).await; assert!(remaining_safe.is_some(), "Safe relationship should remain"); assert!( diff --git a/common/src/storage/types/text_chunk.rs b/common/src/storage/types/text_chunk.rs index 4aae9ed..73c8306 100644 --- a/common/src/storage/types/text_chunk.rs +++ b/common/src/storage/types/text_chunk.rs @@ -107,7 +107,7 @@ impl TextChunk { /// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings. pub async fn vector_search( take: usize, - query_embedding: Vec, + query_embedding: &[f32], db: &SurrealDbClient, user_id: &str, ) -> Result, AppError> { @@ -137,7 +137,7 @@ impl TextChunk { let mut response = db .query(&sql) - .bind(("embedding", query_embedding)) + .bind(("embedding", query_embedding.to_vec())) .bind(("user_id", user_id.to_string())) .await .map_err(AppError::from)?; @@ -273,7 +273,7 @@ impl TextChunk { let mut processed = 0usize; for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) { let inputs: Vec = batch.iter().map(|chunk| chunk.chunk.clone()).collect(); - let embeddings = provider.embed_batch(inputs).await?; + let embeddings = provider.embed_batch(&inputs).await?; if embeddings.len() != batch.len() { return Err(AppError::internal(format!( "embedding batch returned {} vectors for {} chunks", @@ -720,7 +720,7 @@ mod tests { .with_context(|| "redefine index".to_string())?; let results: Vec = - TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user") + TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user") .await .with_context(|| "vector_search".to_string())?; assert!(results.is_empty()); @@ -756,7 +756,7 @@ mod tests { .with_context(|| "store".to_string())?; let results: Vec = - TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id) .await .with_context(|| "vector_search".to_string())?; @@ -796,7 +796,7 @@ mod tests { .with_context(|| "store chunk2".to_string())?; let results: Vec = - TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id) + TextChunk::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id) .await .with_context(|| "vector_search".to_string())?; @@ -987,7 +987,7 @@ mod tests { .await .with_context(|| "delete chunk".to_string())?; - let results = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id) + let results = TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id) .await .with_context(|| "search should succeed even with orphans".to_string())?; diff --git a/common/src/utils/embedding.rs b/common/src/utils/embedding.rs index cf2dfec..ae38373 100644 --- a/common/src/utils/embedding.rs +++ b/common/src/utils/embedding.rs @@ -372,17 +372,17 @@ impl EmbeddingProvider { /// /// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data. /// Returns an empty `Vec` when `texts` is empty. - pub async fn embed_batch(&self, texts: Vec) -> Result>, EmbeddingError> { + pub async fn embed_batch(&self, texts: &[String]) -> Result>, EmbeddingError> { match &self.inner { EmbeddingInner::Hashed { dimension } => Ok(texts - .into_iter() - .map(|text| hashed_embedding(&text, *dimension)) + .iter() + .map(|text| hashed_embedding(text, *dimension)) .collect()), EmbeddingInner::FastEmbed { pool, .. } => { if texts.is_empty() { return Ok(Vec::new()); } - run_fastembed(pool, texts).await + run_fastembed(pool, texts.to_vec()).await } EmbeddingInner::OpenAI { client, @@ -395,7 +395,7 @@ impl EmbeddingProvider { let request = CreateEmbeddingRequestArgs::default() .model(model.clone()) - .input(texts) + .input(texts.to_vec()) .dimensions(*dimensions) .build()?; diff --git a/evaluations/src/types.rs b/evaluations/src/types.rs index 50cdb5d..b54d369 100644 --- a/evaluations/src/types.rs +++ b/evaluations/src/types.rs @@ -205,7 +205,7 @@ impl EvaluationCandidate { entity_description: Some(entity.entity.description.clone()), entity_category, score: entity.score, - chunks: entity.chunks, + chunks: entity.chunks.as_ref().clone(), } } diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 2af580e..fe0f3c9 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -40,7 +40,7 @@ use crate::{ template_with_headers, TemplateResponse, TemplateResult, ResponseResult, }, }, - utils::pagination::{paginate_items, Pagination}, + utils::pagination::{paginate_items, paginate_slice, Pagination}, }; use url::form_urlencoded; @@ -196,16 +196,18 @@ pub async fn create_knowledge_entity( let source_id = format!("manual::{}", Uuid::new_v4()); let new_entity = KnowledgeEntity::new( source_id, - name.clone(), - description.clone(), + name, + description, entity_type, None, user.id.clone(), ); + let new_entity_id = new_entity.id.clone(); - KnowledgeEntity::store_with_embedding(new_entity.clone(), embedding, &state.db).await?; + KnowledgeEntity::store_with_embedding(new_entity, embedding, &state.db).await?; let relationship_type = relationship_type_or_default(form.relationship_type.as_deref()); + let user_id = user.id.clone(); debug!("form: {:?}", form); if !form.relationship_ids.is_empty() { @@ -217,7 +219,7 @@ pub async fn create_knowledge_entity( let mut unique_ids: HashSet = HashSet::new(); for target_id in form.relationship_ids { - if target_id == new_entity.id { + if target_id == new_entity_id { continue; } if !valid_ids.contains(&target_id) { @@ -228,10 +230,10 @@ pub async fn create_knowledge_entity( } let relationship = KnowledgeRelationship::new( - new_entity.id.clone(), + new_entity_id.clone(), target_id, - user.id.clone(), - format!("manual::{}", new_entity.id), + user_id.clone(), + format!("manual::{new_entity_id}"), relationship_type.clone(), ); relationship.store_relationship(&state.db).await?; @@ -385,7 +387,7 @@ async fn suggest_related_entities( let suggestion_min_rrf_score = 1.0 / (tuning.chunk_rrf_k + 1.0); let (vector_rows, fts_rows) = tokio::try_join!( - KnowledgeEntity::vector_search(take, embedding, db, user_id), + KnowledgeEntity::vector_search(take, &embedding, db, user_id), async { if fts_enabled { KnowledgeEntity::fts_search(take, &fts_query, db, user_id).await @@ -480,10 +482,13 @@ fn build_relationship_options( options } -fn build_relationship_table_data( - entities: Vec, +fn build_relationship_rows( relationships: Vec, -) -> RelationshipTableData { +) -> ( + Vec, + Vec, + String, +) { let relationship_type_options = collect_relationship_type_options(&relationships); let mut frequency: HashMap = HashMap::new(); let relationships = relationships @@ -503,7 +508,25 @@ fn build_relationship_table_data( .collect(); let default_relationship_type = frequency .into_iter() - .max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label); + .max_by_key(|(_, count)| *count) + .map_or_else( + || DEFAULT_RELATIONSHIP_TYPE.to_string(), + |(label, _)| label, + ); + + ( + relationships, + relationship_type_options, + default_relationship_type, + ) +} + +fn build_relationship_table_data( + entities: Vec, + relationships: Vec, +) -> RelationshipTableData { + let (relationships, relationship_type_options, default_relationship_type) = + build_relationship_rows(relationships); RelationshipTableData { entities, @@ -532,7 +555,7 @@ async fn build_knowledge_base_data( }; let (visible_entities, pagination) = - paginate_items(entities.clone(), params.page, KNOWLEDGE_ENTITIES_PER_PAGE); + paginate_slice(&entities, params.page, KNOWLEDGE_ENTITIES_PER_PAGE); let page_query = { let mut serializer = form_urlencoded::Serializer::new(String::new()); @@ -551,17 +574,15 @@ async fn build_knowledge_base_data( }; let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?; - let entity_id_set: HashSet = entities.iter().map(|e| e.id.clone()).collect(); + let entity_id_set: HashSet<&str> = entities.iter().map(|e| e.id.as_str()).collect(); let filtered_relationships: Vec = relationships .into_iter() - .filter(|rel| entity_id_set.contains(&rel.in_) && entity_id_set.contains(&rel.out)) + .filter(|rel| { + entity_id_set.contains(rel.in_.as_str()) && entity_id_set.contains(rel.out.as_str()) + }) .collect(); - let RelationshipTableData { - entities: _, - relationships, - relationship_type_options, - default_relationship_type, - } = build_relationship_table_data(entities.clone(), filtered_relationships); + let (relationships, relationship_type_options, default_relationship_type) = + build_relationship_rows(filtered_relationships); Ok(KnowledgeBaseData { entities, diff --git a/html-router/src/utils/pagination.rs b/html-router/src/utils/pagination.rs index 54a8c6f..b20ffd2 100644 --- a/html-router/src/utils/pagination.rs +++ b/html-router/src/utils/pagination.rs @@ -57,6 +57,47 @@ impl Pagination { } } +/// Returns a cloned page slice and pagination metadata without consuming the source list. +pub fn paginate_slice( + items: &[T], + requested_page: Option, + per_page: usize, +) -> (Vec, Pagination) { + let per_page = per_page.max(1); + let total_items = items.len(); + let total_pages = if total_items == 0 { + 0 + } else { + total_items + .saturating_sub(1) + .checked_div(per_page) + .unwrap_or(0) + .saturating_add(1) + }; + + let mut current_page = requested_page.unwrap_or(1); + if current_page == 0 { + current_page = 1; + } + if total_pages > 0 { + current_page = current_page.min(total_pages); + } else { + current_page = 1; + } + + let offset = if total_pages == 0 { + 0 + } else { + per_page.saturating_mul(current_page.saturating_sub(1)) + }; + + let page_items: Vec = items.iter().skip(offset).take(per_page).cloned().collect(); + let page_len = page_items.len(); + let pagination = Pagination::new(current_page, per_page, total_items, total_pages, page_len); + + (page_items, pagination) +} + /// Returns the items for the requested page along with pagination metadata. pub fn paginate_items( items: Vec, @@ -96,7 +137,7 @@ pub fn paginate_items( #[cfg(test)] mod tests { - use super::paginate_items; + use super::{paginate_items, paginate_slice}; #[test] fn paginates_basic_case() { @@ -128,6 +169,16 @@ mod tests { assert_eq!(meta.end_index, 0); } + #[test] + fn paginate_slice_clones_only_page_items() { + let items: Vec<_> = (1..=25).collect(); + let (page, meta) = paginate_slice(&items, Some(2), 10); + + assert_eq!(page, vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20]); + assert_eq!(items.len(), 25); + assert_eq!(meta.current_page, 2); + } + #[test] fn clamps_page_to_bounds() { let items: Vec<_> = (1..=5).collect(); diff --git a/ingestion-pipeline/src/pipeline/services.rs b/ingestion-pipeline/src/pipeline/services.rs index 365e6da..b0c084c 100644 --- a/ingestion-pipeline/src/pipeline/services.rs +++ b/ingestion-pipeline/src/pipeline/services.rs @@ -261,23 +261,24 @@ impl PipelineServices for DefaultPipelineServices { // Embed all chunks of this document in one batch: a single lock acquisition and one // blocking hop, letting the backend batch the inference internally. + let batch_len = chunk_candidates.len(); let embeddings = self .embedding_provider - .embed_batch(chunk_candidates.clone()) + .embed_batch(&chunk_candidates) .await .map_err(|e| { AppError::InternalError(format!("FastEmbed embedding for chunks failed: {e}")) })?; - if embeddings.len() != chunk_candidates.len() { + if embeddings.len() != batch_len { return Err(AppError::InternalError(format!( "embedding batch returned {} vectors for {} chunks", embeddings.len(), - chunk_candidates.len() + batch_len ))); } - let mut chunks = Vec::with_capacity(chunk_candidates.len()); + let mut chunks = Vec::with_capacity(batch_len); for (chunk_text, embedding) in chunk_candidates.into_iter().zip(embeddings) { let chunk_struct = TextChunk::new( content.id().to_string(), diff --git a/ingestion-pipeline/src/pipeline/tests.rs b/ingestion-pipeline/src/pipeline/tests.rs index 5c8887d..c702027 100644 --- a/ingestion-pipeline/src/pipeline/tests.rs +++ b/ingestion-pipeline/src/pipeline/tests.rs @@ -91,10 +91,10 @@ impl MockServices { similar_entities: vec![RetrievedEntity { entity: retrieved_entity, score: 0.8, - chunks: vec![RetrievedChunk { + chunks: std::sync::Arc::new(vec![RetrievedChunk { chunk: retrieved_chunk, score: 0.7, - }], + }]), }], analysis, chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM], diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 009b7a1..d27eb32 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -5,6 +5,8 @@ pub mod query; pub mod reranking; pub mod scoring; +use std::sync::Arc; + use common::{ error::AppError, storage::{ @@ -52,7 +54,7 @@ pub struct RetrievedChunk { pub struct RetrievedEntity { pub entity: KnowledgeEntity, pub score: f32, - pub chunks: Vec, + pub chunks: Arc>, } /// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities. diff --git a/retrieval-pipeline/src/pipeline/stages.rs b/retrieval-pipeline/src/pipeline/stages.rs index 857eb36..e752284 100644 --- a/retrieval-pipeline/src/pipeline/stages.rs +++ b/retrieval-pipeline/src/pipeline/stages.rs @@ -4,7 +4,7 @@ use common::{ storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk}, }; use fastembed::RerankResult; -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Write, sync::Arc}; use tracing::{debug, instrument, warn}; use crate::{ @@ -106,7 +106,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { #[instrument(level = "trace", skip_all)] pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> { debug!("Collecting chunk candidates via vector and FTS search"); - let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); + let embedding = ctx.ensure_embedding().map_err(|e| *e)?; let tuning = &ctx.config.tuning; let fts_take = tuning.chunk_fts_take; let (fts_query, fts_token_count) = normalize_fts_terms(&ctx.input_text); @@ -233,12 +233,16 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr let mut best_score: HashMap = HashMap::new(); for scored in &ctx.chunk_values { - let source = scored.item.source_id.clone(); - let attached = chunks_by_source.entry(source.clone()).or_default(); - if attached.is_empty() { - source_order.push(source.clone()); - best_score.insert(source.clone(), scored.fused); + let source_id = &scored.item.source_id; + let is_new_source = !chunks_by_source.contains_key(source_id); + if is_new_source { + source_order.push(source_id.clone()); + best_score.insert(source_id.clone(), scored.fused); } + + let attached = chunks_by_source + .entry(source_id.clone()) + .or_default(); if attached.len() < max_chunks { attached.push(RetrievedChunk { chunk: scored.item.clone(), @@ -247,6 +251,11 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr } } + let chunks_by_source: HashMap>> = chunks_by_source + .into_iter() + .map(|(source, chunks)| (source, Arc::new(chunks))) + .collect(); + let entities = KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?; @@ -264,12 +273,15 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr continue; }; let score = best_score.get(source).copied().unwrap_or(0.0); - let chunks = chunks_by_source.get(source).cloned().unwrap_or_default(); + let chunks = chunks_by_source + .get(source) + .cloned() + .unwrap_or_else(|| Arc::new(Vec::new())); for entity in entities { results.push(RetrievedEntity { entity, score, - chunks: chunks.clone(), + chunks: Arc::clone(&chunks), }); } } @@ -328,17 +340,26 @@ where } fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { - chunks - .iter() - .take(max_chunks) - .map(|chunk| { - format!( - "Source: {}\nChunk:\n{}", - chunk.item.source_id, - chunk.item.chunk.trim() - ) - }) - .collect() + let take = chunks.len().min(max_chunks); + let mut documents = Vec::with_capacity(take); + let mut buffer = String::with_capacity(512); + + for chunk in chunks.iter().take(max_chunks) { + buffer.clear(); + let _ = write!( + buffer, + "Source: {}\nChunk:\n{}", + chunk.item.source_id, + chunk.item.chunk.trim() + ); + let next_capacity = buffer.capacity().max(512); + documents.push(std::mem::replace( + &mut buffer, + String::with_capacity(next_capacity), + )); + } + + documents } fn apply_chunk_rerank_results( diff --git a/retrieval-pipeline/src/scoring.rs b/retrieval-pipeline/src/scoring.rs index 1c51e2b..849d5ed 100644 --- a/retrieval-pipeline/src/scoring.rs +++ b/retrieval-pipeline/src/scoring.rs @@ -1,4 +1,7 @@ -use std::{cmp::Ordering, collections::HashMap}; +use std::{ + cmp::Ordering, + collections::{hash_map::Entry, HashMap}, +}; use common::storage::types::StoredObject; @@ -119,7 +122,7 @@ pub fn reciprocal_rank_fusion( config: RrfConfig, ) -> Vec> where - T: StoredObject + Clone, + T: StoredObject, { let mut merged: HashMap> = HashMap::new(); let k = if config.k <= 0.0 { 60.0 } else { config.k }; @@ -146,19 +149,30 @@ where for (rank, candidate) in vector_ranked.into_iter().enumerate() { let id = candidate.item.id().to_owned(); - let entry = merged - .entry(id.clone()) - .or_insert_with(|| Scored::new(candidate.item.clone())); + let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); + let contribution = vector_weight / (k + rank_f32 + 1.0); - if let Some(score) = candidate.scores.vector { - let existing = entry.scores.vector.unwrap_or(f32::MIN); - if score > existing { - entry.scores.vector = Some(score); + match merged.entry(id) { + Entry::Occupied(mut occupied) => { + let entry = occupied.get_mut(); + if let Some(score) = candidate.scores.vector { + let existing = entry.scores.vector.unwrap_or(f32::MIN); + if score > existing { + entry.scores.vector = Some(score); + } + } + entry.item = candidate.item; + entry.fused += contribution; + } + Entry::Vacant(vacant) => { + let mut scored = Scored::new(candidate.item); + if let Some(score) = candidate.scores.vector { + scored.scores.vector = Some(score); + } + scored.fused = contribution; + vacant.insert(scored); } } - entry.item = candidate.item; - let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); - entry.fused += vector_weight / (k + rank_f32 + 1.0); } } @@ -174,19 +188,30 @@ where for (rank, candidate) in fts_ranked.into_iter().enumerate() { let id = candidate.item.id().to_owned(); - let entry = merged - .entry(id.clone()) - .or_insert_with(|| Scored::new(candidate.item.clone())); + let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); + let contribution = fts_weight / (k + rank_f32 + 1.0); - if let Some(score) = candidate.scores.fts { - let existing = entry.scores.fts.unwrap_or(f32::MIN); - if score > existing { - entry.scores.fts = Some(score); + match merged.entry(id) { + Entry::Occupied(mut occupied) => { + let entry = occupied.get_mut(); + if let Some(score) = candidate.scores.fts { + let existing = entry.scores.fts.unwrap_or(f32::MIN); + if score > existing { + entry.scores.fts = Some(score); + } + } + entry.item = candidate.item; + entry.fused += contribution; + } + Entry::Vacant(vacant) => { + let mut scored = Scored::new(candidate.item); + if let Some(score) = candidate.scores.fts { + scored.scores.fts = Some(score); + } + scored.fused = contribution; + vacant.insert(scored); } } - entry.item = candidate.item; - let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from); - entry.fused += fts_weight / (k + rank_f32 + 1.0); } }