diff --git a/common/src/storage/types/knowledge_entity.rs b/common/src/storage/types/knowledge_entity.rs index 1d9c48c..4509b5a 100644 --- a/common/src/storage/types/knowledge_entity.rs +++ b/common/src/storage/types/knowledge_entity.rs @@ -44,36 +44,12 @@ impl From for KnowledgeEntityType { } } -#[derive(Debug, Deserialize, Serialize)] +/// Search result including hydrated entity. +#[allow(clippy::module_name_repetitions)] +#[derive(Debug, Clone, PartialEq)] pub struct KnowledgeEntitySearchResult { - #[serde(deserialize_with = "deserialize_flexible_id")] - pub id: String, - #[serde( - serialize_with = "serialize_datetime", - deserialize_with = "deserialize_datetime", - default - )] - pub created_at: DateTime, - #[serde( - serialize_with = "serialize_datetime", - deserialize_with = "deserialize_datetime", - default - )] - pub updated_at: DateTime, - - pub source_id: String, - pub name: String, - pub description: String, - pub entity_type: KnowledgeEntityType, - #[serde(default)] - pub metadata: Option, - pub user_id: String, - + pub entity: KnowledgeEntity, pub score: f32, - #[serde(default)] - pub highlighted_name: Option, - #[serde(default)] - pub highlighted_description: Option, } stored_object!(KnowledgeEntity, "knowledge_entity", { @@ -85,13 +61,6 @@ stored_object!(KnowledgeEntity, "knowledge_entity", { user_id: String }); -/// Vector search result including hydrated entity. -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct KnowledgeEntityVectorResult { - pub entity: KnowledgeEntity, - pub score: f32, -} - impl KnowledgeEntity { #[must_use] pub fn new( @@ -116,12 +85,33 @@ impl KnowledgeEntity { } } - pub async fn search( + /// Full-text search over knowledge entities using the BM25 FTS index. + pub async fn fts_search( + take: usize, + terms: &str, db: &SurrealDbClient, - search_terms: &str, user_id: &str, - limit: usize, ) -> Result, AppError> { + #[derive(Deserialize)] + struct Row { + #[serde(deserialize_with = "deserialize_flexible_id")] + id: String, + #[serde(deserialize_with = "deserialize_datetime")] + created_at: DateTime, + #[serde(deserialize_with = "deserialize_datetime")] + updated_at: DateTime, + source_id: String, + name: String, + description: String, + entity_type: KnowledgeEntityType, + #[serde(default)] + metadata: Option, + user_id: String, + score: f32, + } + + let limit = i64::try_from(take).unwrap_or(i64::MAX); + let sql = r#" SELECT id, @@ -133,8 +123,6 @@ impl KnowledgeEntity { entity_type, metadata, user_id, - search::highlight('', '', 0) AS highlighted_name, - search::highlight('', '', 1) AS highlighted_description, ( IF search::score(0) != NONE THEN search::score(0) ELSE 0 END + IF search::score(1) != NONE THEN search::score(1) ELSE 0 END @@ -150,14 +138,32 @@ impl KnowledgeEntity { LIMIT $limit; "#; - Ok(db + let rows: Vec = db .client .query(sql) - .bind(("terms", search_terms.to_owned())) + .bind(("terms", terms.to_owned())) .bind(("user_id", user_id.to_owned())) .bind(("limit", limit)) .await? - .take(0)?) + .take(0)?; + + Ok(rows + .into_iter() + .map(|row| KnowledgeEntitySearchResult { + entity: KnowledgeEntity { + id: row.id, + created_at: row.created_at, + updated_at: row.updated_at, + source_id: row.source_id, + name: row.name, + description: row.description, + entity_type: row.entity_type, + metadata: row.metadata, + user_id: row.user_id, + }, + score: row.score, + }) + .collect()) } /// Fetch all knowledge entities owned by any of the provided source ids for a user. @@ -260,7 +266,7 @@ impl KnowledgeEntity { query_embedding: Vec, db: &SurrealDbClient, user_id: &str, - ) -> Result, AppError> { + ) -> Result, AppError> { #[derive(Deserialize)] struct Row { entity_id: Option, @@ -297,7 +303,7 @@ impl KnowledgeEntity { Ok(rows .into_iter() .filter_map(|r| { - r.entity_id.map(|entity| KnowledgeEntityVectorResult { + r.entity_id.map(|entity| KnowledgeEntitySearchResult { entity, score: r.score, }) @@ -605,12 +611,35 @@ impl KnowledgeEntity { mod tests { #![allow(clippy::expect_used, clippy::must_use_candidate)] use super::*; + use crate::storage::indexes::rebuild; use crate::storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding; use crate::test_utils::configure_embedding_dimension; use anyhow::{self, Context}; - use serde_json::json; use uuid::Uuid; + async fn ensure_entity_fts_indexes(db: &SurrealDbClient) -> anyhow::Result<()> { + let snowball_sql = r#" + DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii, snowball(english); + DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25; + DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25; + "#; + + if let Err(err) = db.client.query(snowball_sql).await { + let fallback_sql = r#" + DEFINE ANALYZER OVERWRITE app_en_fts_analyzer TOKENIZERS class, punct FILTERS lowercase, ascii; + DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity FIELDS name SEARCH ANALYZER app_en_fts_analyzer BM25; + DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity FIELDS description SEARCH ANALYZER app_en_fts_analyzer BM25; + "#; + + db.client + .query(fallback_sql) + .await + .with_context(|| format!("define entity fts index fallback: {err}"))?; + } + Ok(()) + } + use serde_json::json; + #[tokio::test] async fn test_knowledge_entity_creation() -> anyhow::Result<()> { let source_id = "source123".to_string(); @@ -1106,4 +1135,134 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_fts_search_returns_empty_when_no_entities() -> anyhow::Result<()> { + let namespace = "fts_entity_ns_empty"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "migrations".to_string())?; + ensure_entity_fts_indexes(&db).await?; + rebuild(&db) + .await + .with_context(|| "rebuild indexes".to_string())?; + + let results = KnowledgeEntity::fts_search(5, "hello", &db, "user") + .await + .with_context(|| "fts search".to_string())?; + + assert!(results.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn test_fts_search_single_result() -> anyhow::Result<()> { + let namespace = "fts_entity_ns_single"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "migrations".to_string())?; + ensure_entity_fts_indexes(&db).await?; + + let user_id = "fts_user"; + let entity = KnowledgeEntity::new( + "fts_src".to_string(), + "cucumber".to_string(), + "cucumbers are best".to_string(), + KnowledgeEntityType::Document, + None, + user_id.to_string(), + ); + db.store_item(entity.clone()) + .await + .with_context(|| "store entity".to_string())?; + rebuild(&db) + .await + .with_context(|| "rebuild indexes".to_string())?; + + let results = KnowledgeEntity::fts_search(3, "cucumber", &db, user_id) + .await + .with_context(|| "fts search".to_string())?; + + assert_eq!(results.len(), 1); + let r0 = results.first().context("expected first result")?; + assert_eq!(r0.entity.id, entity.id); + assert!(r0.score.is_finite(), "expected a finite FTS score"); + Ok(()) + } + + #[tokio::test] + async fn test_fts_search_orders_by_score_and_filters_user() -> anyhow::Result<()> { + let namespace = "fts_entity_ns_order"; + let database = &Uuid::new_v4().to_string(); + let db = SurrealDbClient::memory(namespace, database) + .await + .with_context(|| "Failed to start in-memory surrealdb".to_string())?; + db.apply_migrations() + .await + .with_context(|| "migrations".to_string())?; + ensure_entity_fts_indexes(&db).await?; + + let user_id = "fts_user_order"; + let high_score_entity = KnowledgeEntity::new( + "src1".to_string(), + "apple apple apple pie".to_string(), + "dessert recipe".to_string(), + KnowledgeEntityType::Document, + None, + user_id.to_string(), + ); + let low_score_entity = KnowledgeEntity::new( + "src2".to_string(), + "apple tart".to_string(), + "light dessert".to_string(), + KnowledgeEntityType::Document, + None, + user_id.to_string(), + ); + let other_user_entity = KnowledgeEntity::new( + "src3".to_string(), + "apple orchard".to_string(), + "farming guide".to_string(), + KnowledgeEntityType::Document, + None, + "other_user".to_string(), + ); + + db.store_item(high_score_entity.clone()) + .await + .with_context(|| "store high score entity".to_string())?; + db.store_item(low_score_entity.clone()) + .await + .with_context(|| "store low score entity".to_string())?; + db.store_item(other_user_entity) + .await + .with_context(|| "store other user entity".to_string())?; + rebuild(&db) + .await + .with_context(|| "rebuild indexes".to_string())?; + + let results = KnowledgeEntity::fts_search(3, "apple", &db, user_id) + .await + .with_context(|| "fts search".to_string())?; + + assert_eq!(results.len(), 2); + let ids: Vec<_> = results.iter().map(|r| r.entity.id.as_str()).collect(); + assert!( + ids.contains(&high_score_entity.id.as_str()) + && ids.contains(&low_score_entity.id.as_str()), + "expected only the two entities for the same user" + ); + let r0 = results.first().context("expected first result")?; + let r1 = results.get(1).context("expected second result")?; + assert!(r0.score >= r1.score); + Ok(()) + } } diff --git a/docs/features.md b/docs/features.md index 9af4163..bd4c8a6 100644 --- a/docs/features.md +++ b/docs/features.md @@ -27,14 +27,16 @@ The D3-based graph visualization shows entities as nodes and relationships as ed ## Hybrid Retrieval -Minne uses chunk-first hybrid retrieval over the knowledge base: +Minne uses hybrid retrieval over the knowledge base: -- **Vector similarity** — Semantic matching via embeddings over text chunks -- **Full-text search** — Keyword matching with BM25 over the same chunk index +- **Vector similarity** — Semantic matching via embeddings +- **Full-text search** — Keyword matching with BM25 -The two ranked candidate lists are merged with Reciprocal Rank Fusion (RRF). When a caller needs knowledge entities (search, ingestion linking, relationship suggestion), entities are derived from the top retrieved chunks grouped by `source_id`. +For **content search** (chat, global search, ingestion linking), retrieval is chunk-first: vector and FTS run over `text_chunk` rows, merged with Reciprocal Rank Fusion (RRF). When entities are needed, they are derived from the top retrieved chunks grouped by `source_id`. -Optional **reranking** can rescore the fused chunk list with a cross-encoder model; see below. +For **relationship suggestions** when creating an entity, retrieval is entity-first: vector and FTS run directly over `knowledge_entity` name/description and embedding indexes, then merged with the same RRF approach. + +Optional **reranking** can rescore fused chunk lists with a cross-encoder model; see below. ## Reranking (Optional) diff --git a/html-router/src/routes/knowledge/handlers.rs b/html-router/src/routes/knowledge/handlers.rs index 585b77d..0a0578a 100644 --- a/html-router/src/routes/knowledge/handlers.rs +++ b/html-router/src/routes/knowledge/handlers.rs @@ -16,14 +16,19 @@ use serde::{ use common::{ error::AppError, - storage::types::{ - knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, - knowledge_relationship::KnowledgeRelationship, - user::User, + storage::{ + db::SurrealDbClient, + types::{ + knowledge_entity::{KnowledgeEntity, KnowledgeEntityType}, + knowledge_relationship::KnowledgeRelationship, + user::User, + }, }, - utils::embedding::generate_embedding_with_provider, + utils::embedding::{generate_embedding_with_provider, EmbeddingProvider}, +}; +use retrieval_pipeline::{ + normalize_fts_terms, reciprocal_rank_fusion, RetrievalTuning, RrfConfig, Scored, }; -use retrieval_pipeline; use tracing::debug; use uuid::Uuid; @@ -43,7 +48,6 @@ const KNOWLEDGE_ENTITIES_PER_PAGE: usize = 12; const RELATIONSHIP_TYPE_OPTIONS: &[&str] = &["RelatedTo", "RelevantTo", "SimilarTo", "References"]; const DEFAULT_RELATIONSHIP_TYPE: &str = "RelatedTo"; const MAX_RELATIONSHIP_SUGGESTIONS: usize = 10; -const SUGGESTION_MIN_SCORE: f32 = 0.5; const GRAPH_REFRESH_TRIGGER: &str = r#"{"knowledge-graph-refresh":true}"#; const RELATIONSHIP_TYPE_ALIASES: &[(&str, &str)] = &[("relatesto", "RelatedTo")]; @@ -279,38 +283,30 @@ pub async fn suggest_knowledge_relationships( } if !query_parts.is_empty() { - let query = query_parts.join(" "); - let rerank_lease = match state.reranker_pool.as_ref() { - Some(pool) => pool.checkout().await, - None => None, - }; + let name = form.name.as_deref().unwrap_or("").trim(); + let description = form.description.as_deref().unwrap_or("").trim(); + let entity_type = form.entity_type.as_deref().map_or( + KnowledgeEntityType::Document, + |value| KnowledgeEntityType::from(value.to_string()), + ); - let config = retrieval_pipeline::RetrievalConfig::with_entities(); - if let Ok(retrieval_pipeline::RetrievalOutput::WithEntities { entities, .. }) = - retrieval_pipeline::retrieve( - &state.db, - &state.openai_client, - Some(&*state.embedding_provider), - &query, - &user.id, - config, - rerank_lease, - ) - .await - { - for retrieval_pipeline::RetrievedEntity { entity, score, .. } in entities { - if suggestion_scores.len() >= MAX_RELATIONSHIP_SUGGESTIONS { - break; - } - if score.is_nan() || score < SUGGESTION_MIN_SCORE { - continue; - } - if !entity_lookup.contains_key(&entity.id) { - continue; - } - suggestion_scores.insert(entity.id.clone(), score); - selected_ids.insert(entity.id.clone()); - } + let suggested = suggest_related_entities( + &state.db, + &state.embedding_provider, + &user.id, + DraftEntityQuery { + name, + description, + entity_type, + search_terms: &query_parts.join(" "), + }, + &entity_lookup, + ) + .await?; + + for (id, score) in suggested { + selected_ids.insert(id.clone()); + suggestion_scores.insert(id, score); } } @@ -359,6 +355,90 @@ pub struct RelationshipTableRow { relationship_type_label: String, } +struct DraftEntityQuery<'a> { + name: &'a str, + description: &'a str, + entity_type: KnowledgeEntityType, + search_terms: &'a str, +} + +async fn suggest_related_entities( + db: &SurrealDbClient, + embedding_provider: &EmbeddingProvider, + user_id: &str, + draft: DraftEntityQuery<'_>, + entity_lookup: &HashMap, +) -> Result, AppError> { + let embedding_input = format!( + "name: {}, description: {}, type: {:?}", + draft.name, draft.description, draft.entity_type + ); + let embedding = + generate_embedding_with_provider(embedding_provider, &embedding_input).await?; + + let take = MAX_RELATIONSHIP_SUGGESTIONS * 2; + let tuning = RetrievalTuning::default(); + let (fts_query, fts_token_count) = normalize_fts_terms(draft.search_terms); + let fts_enabled = tuning.flags.chunk_rrf_use_fts() && !fts_query.is_empty(); + let suggestion_min_rrf_score = 1.0 / (tuning.chunk_rrf_k + 1.0); + + let (vector_rows, fts_rows) = tokio::try_join!( + KnowledgeEntity::vector_search(take, embedding, db, user_id), + async { + if fts_enabled { + KnowledgeEntity::fts_search(take, &fts_query, db, user_id).await + } else { + Ok(Vec::new()) + } + } + )?; + + let fts_candidates = fts_rows.len(); + + let vector_scored: Vec> = vector_rows + .into_iter() + .map(|row| Scored::new(row.entity).with_vector_score(row.score)) + .collect(); + + let fts_scored: Vec> = fts_rows + .into_iter() + .map(|row| Scored::new(row.entity).with_fts_score(row.score)) + .collect(); + + let mut fts_weight = tuning.chunk_rrf_fts_weight; + if fts_enabled && fts_token_count > 0 && fts_token_count <= 3 { + fts_weight *= 1.5; + } + + let fused = reciprocal_rank_fusion( + vector_scored, + fts_scored, + RrfConfig { + k: tuning.chunk_rrf_k, + vector_weight: tuning.chunk_rrf_vector_weight, + fts_weight, + use_vector: tuning.flags.chunk_rrf_use_vector(), + use_fts: tuning.flags.chunk_rrf_use_fts() && fts_candidates > 0, + }, + ); + + let mut suggestions = HashMap::new(); + for scored in fused { + if suggestions.len() >= MAX_RELATIONSHIP_SUGGESTIONS { + break; + } + if scored.fused.is_nan() || scored.fused < suggestion_min_rrf_score { + continue; + } + if !entity_lookup.contains_key(&scored.item.id) { + continue; + } + suggestions.insert(scored.item.id, scored.fused); + } + + Ok(suggestions) +} + fn build_relationship_options( entities: Vec, selected_ids: &HashSet, @@ -618,6 +698,7 @@ impl<'de> Deserialize<'de> for CreateKnowledgeEntityParams { pub struct SuggestRelationshipsParams { pub name: Option, pub description: Option, + pub entity_type: Option, pub relationship_type: Option, pub relationship_ids: Vec, } @@ -653,6 +734,7 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams { { let mut name: Option = None; let mut description: Option = None; + let mut entity_type: Option = None; let mut relationship_type: Option = None; let mut relationship_ids: Vec = Vec::new(); @@ -687,7 +769,13 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams { } } Field::EntityType => { - map.next_value::()?; + let value: String = map.next_value()?; + let trimmed = value.trim(); + if trimmed.is_empty() { + entity_type = None; + } else { + entity_type = Some(trimmed.to_owned()); + } } Field::RelationshipIds => { let value: String = map.next_value()?; @@ -702,6 +790,7 @@ impl<'de> Deserialize<'de> for SuggestRelationshipsParams { Ok(SuggestRelationshipsParams { name, description, + entity_type, relationship_type, relationship_ids, }) diff --git a/retrieval-pipeline/src/lib.rs b/retrieval-pipeline/src/lib.rs index 6332d18..3bc9068 100644 --- a/retrieval-pipeline/src/lib.rs +++ b/retrieval-pipeline/src/lib.rs @@ -1,9 +1,9 @@ pub mod answer_retrieval; pub mod pipeline; +pub mod query; pub mod reranking; - -pub(crate) mod scoring; +pub mod scoring; use common::{ error::AppError, @@ -29,9 +29,11 @@ pub enum RetrievalOutput { } pub use pipeline::{ - retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, StageKind, - StageTimings, + retrieved_entities_to_json, Diagnostics, RetrievalConfig, RetrievalParams, RetrievalTuning, + StageKind, StageTimings, }; +pub use query::normalize_fts_terms; +pub use scoring::{reciprocal_rank_fusion, RrfConfig, Scored}; /// Round a score to three decimal places for JSON output. pub(crate) fn round_score(value: f32) -> f64 { diff --git a/retrieval-pipeline/src/pipeline/config.rs b/retrieval-pipeline/src/pipeline/config.rs index b047b1c..b6714d0 100644 --- a/retrieval-pipeline/src/pipeline/config.rs +++ b/retrieval-pipeline/src/pipeline/config.rs @@ -117,8 +117,8 @@ impl Default for RetrievalTuning { /// Per-request retrieval configuration. /// /// The pipeline always performs chunk-first hybrid retrieval. Set `resolve_entities` -/// when a caller additionally needs the `KnowledgeEntity` rows that own the retrieved -/// chunks (search, ingestion linking, relationship suggestion). +/// when a caller additionally needs the `KnowledgeEntity` rows that own retrieved +/// chunks (search, ingestion linking). #[derive(Debug, Clone, Default)] pub struct RetrievalConfig { pub tuning: RetrievalTuning, diff --git a/retrieval-pipeline/src/pipeline/context.rs b/retrieval-pipeline/src/pipeline/context.rs index a234af8..3111d5e 100644 --- a/retrieval-pipeline/src/pipeline/context.rs +++ b/retrieval-pipeline/src/pipeline/context.rs @@ -5,7 +5,9 @@ use common::{ utils::embedding::EmbeddingProvider, }; -use crate::{reranking::RerankerLease, scoring::Scored, RetrievedChunk, RetrievedEntity}; +use crate::scoring::Scored; + +use crate::{reranking::RerankerLease, RetrievedChunk, RetrievedEntity}; use super::{ config::RetrievalConfig, diff --git a/retrieval-pipeline/src/pipeline/mod.rs b/retrieval-pipeline/src/pipeline/mod.rs index 238d0d5..1761099 100644 --- a/retrieval-pipeline/src/pipeline/mod.rs +++ b/retrieval-pipeline/src/pipeline/mod.rs @@ -3,7 +3,7 @@ mod context; mod diagnostics; mod stages; -pub use config::RetrievalConfig; +pub use config::{RetrievalConfig, RetrievalTuning}; pub use diagnostics::Diagnostics; use crate::{round_score, RetrievalOutput, RetrievedEntity}; diff --git a/retrieval-pipeline/src/pipeline/stages.rs b/retrieval-pipeline/src/pipeline/stages.rs index 7de0bad..a9b107c 100644 --- a/retrieval-pipeline/src/pipeline/stages.rs +++ b/retrieval-pipeline/src/pipeline/stages.rs @@ -9,9 +9,8 @@ use std::collections::HashMap; use tracing::{debug, instrument, warn}; use crate::{ - scoring::{ - clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored, - }, + query::normalize_fts_terms, + scoring::{clamp_unit, min_max_normalize, reciprocal_rank_fusion, RrfConfig, Scored}, RetrievedChunk, RetrievedEntity, }; @@ -115,7 +114,7 @@ pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone(); let tuning = &ctx.config.tuning; let fts_take = tuning.chunk_fts_take; - let (fts_query, fts_token_count) = normalize_fts_query(&ctx.input_text); + let (fts_query, fts_token_count) = normalize_fts_terms(&ctx.input_text); let fts_enabled = tuning.flags.chunk_rrf_use_fts() && fts_take > 0 && !fts_query.is_empty(); let (vector_rows, fts_rows) = tokio::try_join!( @@ -333,26 +332,6 @@ where items.iter().take(SCORE_SAMPLE_LIMIT).map(extractor).collect() } -fn normalize_fts_query(input: &str) -> (String, usize) { - const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"]; - let mut cleaned = String::with_capacity(input.len()); - for ch in input.chars() { - if ch.is_alphanumeric() { - cleaned.extend(ch.to_lowercase()); - } else if ch.is_whitespace() { - cleaned.push(' '); - } - } - let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3)); - for token in cleaned.split_whitespace() { - if !STOPWORDS.contains(&token) && !token.is_empty() { - tokens.push(token.to_string()); - } - } - let normalized = tokens.join(" "); - (normalized, tokens.len()) -} - fn build_chunk_rerank_documents(chunks: &[Scored], max_chunks: usize) -> Vec { chunks .iter() diff --git a/retrieval-pipeline/src/query.rs b/retrieval-pipeline/src/query.rs new file mode 100644 index 0000000..9229f3a --- /dev/null +++ b/retrieval-pipeline/src/query.rs @@ -0,0 +1,39 @@ +/// Normalize raw input into FTS-friendly terms and return the token count. +pub fn normalize_fts_terms(input: &str) -> (String, usize) { + const STOPWORDS: &[&str] = &["the", "a", "an", "of", "in", "on", "and", "or", "to", "for"]; + let mut cleaned = String::with_capacity(input.len()); + for ch in input.chars() { + if ch.is_alphanumeric() { + cleaned.extend(ch.to_lowercase()); + } else if ch.is_whitespace() { + cleaned.push(' '); + } + } + let mut tokens = Vec::with_capacity(cleaned.len().div_ceil(3)); + for token in cleaned.split_whitespace() { + if !STOPWORDS.contains(&token) && !token.is_empty() { + tokens.push(token.to_string()); + } + } + let normalized = tokens.join(" "); + (normalized, tokens.len()) +} + +#[cfg(test)] +mod tests { + use super::normalize_fts_terms; + + #[test] + fn strips_stopwords_and_lowercases() { + let (query, count) = normalize_fts_terms("The Cucumber and Tomatoes"); + assert_eq!(query, "cucumber tomatoes"); + assert_eq!(count, 2); + } + + #[test] + fn returns_empty_for_stopwords_only() { + let (query, count) = normalize_fts_terms("the and or"); + assert!(query.is_empty()); + assert_eq!(count, 0); + } +}