mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-01 18:41:37 +02:00
fix: replaced several instances of cloning, reduced allocations
This commit is contained in:
@@ -258,7 +258,7 @@ impl KnowledgeEntity {
|
|||||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||||
pub async fn vector_search(
|
pub async fn vector_search(
|
||||||
take: usize,
|
take: usize,
|
||||||
query_embedding: Vec<f32>,
|
query_embedding: &[f32],
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||||
@@ -286,7 +286,7 @@ impl KnowledgeEntity {
|
|||||||
|
|
||||||
let mut response = db
|
let mut response = db
|
||||||
.query(&sql)
|
.query(&sql)
|
||||||
.bind(("embedding", query_embedding))
|
.bind(("embedding", query_embedding.to_vec()))
|
||||||
.bind(("user_id", user_id.to_string()))
|
.bind(("user_id", user_id.to_string()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map_err(AppError::from)?;
|
||||||
@@ -408,7 +408,7 @@ impl KnowledgeEntity {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let embeddings = provider.embed_batch(inputs).await?;
|
let embeddings = provider.embed_batch(&inputs).await?;
|
||||||
if embeddings.len() != batch.len() {
|
if embeddings.len() != batch.len() {
|
||||||
return Err(AppError::internal(format!(
|
return Err(AppError::internal(format!(
|
||||||
"embedding batch returned {} vectors for {} entities",
|
"embedding batch returned {} vectors for {} entities",
|
||||||
@@ -817,7 +817,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("Failed to redefine index length");
|
.expect("Failed to redefine index length");
|
||||||
|
|
||||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||||
.await
|
.await
|
||||||
.expect("vector search");
|
.expect("vector search");
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
@@ -878,7 +878,7 @@ mod tests {
|
|||||||
.with_context(|| "fetch embedding".to_string())?;
|
.with_context(|| "fetch embedding".to_string())?;
|
||||||
assert!(fetched_emb.is_some());
|
assert!(fetched_emb.is_some());
|
||||||
|
|
||||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "vector search".to_string())?;
|
.with_context(|| "vector search".to_string())?;
|
||||||
|
|
||||||
@@ -965,7 +965,7 @@ mod tests {
|
|||||||
.with_context(|| "get embedding e2".to_string())?
|
.with_context(|| "get embedding e2".to_string())?
|
||||||
.is_some());
|
.is_some());
|
||||||
|
|
||||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
let results = KnowledgeEntity::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "vector search".to_string())?;
|
.with_context(|| "vector search".to_string())?;
|
||||||
|
|
||||||
@@ -1030,7 +1030,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.with_context(|| "delete entity".to_string())?;
|
.with_context(|| "delete entity".to_string())?;
|
||||||
|
|
||||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||||
|
|
||||||
|
|||||||
@@ -42,12 +42,24 @@ impl KnowledgeRelationship {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
pub async fn store_relationship(self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||||
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
||||||
.await?;
|
.await?;
|
||||||
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let Self {
|
||||||
|
id,
|
||||||
|
in_,
|
||||||
|
out,
|
||||||
|
metadata:
|
||||||
|
RelationshipMetadata {
|
||||||
|
user_id,
|
||||||
|
source_id,
|
||||||
|
relationship_type,
|
||||||
|
},
|
||||||
|
} = self;
|
||||||
|
|
||||||
db_client
|
db_client
|
||||||
.client
|
.client
|
||||||
.query(
|
.query(
|
||||||
@@ -62,12 +74,12 @@ impl KnowledgeRelationship {
|
|||||||
metadata.relationship_type = $relationship_type;
|
metadata.relationship_type = $relationship_type;
|
||||||
COMMIT TRANSACTION;"#,
|
COMMIT TRANSACTION;"#,
|
||||||
)
|
)
|
||||||
.bind(("rel_id", self.id.clone()))
|
.bind(("rel_id", id))
|
||||||
.bind(("in_id", self.in_.clone()))
|
.bind(("in_id", in_))
|
||||||
.bind(("out_id", self.out.clone()))
|
.bind(("out_id", out))
|
||||||
.bind(("user_id", self.metadata.user_id.clone()))
|
.bind(("user_id", user_id))
|
||||||
.bind(("source_id", self.metadata.source_id.clone()))
|
.bind(("source_id", source_id))
|
||||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
.bind(("relationship_type", relationship_type))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?
|
.map_err(AppError::from)?
|
||||||
.check()
|
.check()
|
||||||
@@ -230,13 +242,14 @@ mod tests {
|
|||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to store relationship".to_string())?;
|
.with_context(|| "Failed to store relationship".to_string())?;
|
||||||
|
|
||||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
let persisted = get_relationship_by_id(&relationship_id, &db)
|
||||||
.await
|
.await
|
||||||
.expect("Relationship should be retrievable by id");
|
.expect("Relationship should be retrievable by id");
|
||||||
assert_eq!(persisted.in_, entity1_id);
|
assert_eq!(persisted.in_, entity1_id);
|
||||||
@@ -296,6 +309,7 @@ mod tests {
|
|||||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -305,7 +319,7 @@ mod tests {
|
|||||||
let mut res = db
|
let mut res = db
|
||||||
.client
|
.client
|
||||||
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||||
.bind(("id", relationship.id.clone()))
|
.bind(("id", relationship_id))
|
||||||
.await
|
.await
|
||||||
.expect("query relationship by id failed");
|
.expect("query relationship by id failed");
|
||||||
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||||
@@ -338,6 +352,7 @@ mod tests {
|
|||||||
source_id.clone(),
|
source_id.clone(),
|
||||||
relationship_type,
|
relationship_type,
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -357,7 +372,7 @@ mod tests {
|
|||||||
"Relationship should exist before deletion"
|
"Relationship should exist before deletion"
|
||||||
);
|
);
|
||||||
|
|
||||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db)
|
KnowledgeRelationship::delete_relationship_by_id(&relationship_id, user_id, &db)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||||
|
|
||||||
@@ -391,6 +406,7 @@ mod tests {
|
|||||||
source_id,
|
source_id,
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship_id = relationship.id.clone();
|
||||||
|
|
||||||
relationship
|
relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -409,7 +425,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||||
&relationship.id,
|
&relationship_id,
|
||||||
"different-user",
|
"different-user",
|
||||||
&db,
|
&db,
|
||||||
)
|
)
|
||||||
@@ -472,6 +488,9 @@ mod tests {
|
|||||||
different_source_id.clone(),
|
different_source_id.clone(),
|
||||||
"mentions".to_string(),
|
"mentions".to_string(),
|
||||||
);
|
);
|
||||||
|
let relationship1_id = relationship1.id.clone();
|
||||||
|
let relationship2_id = relationship2.id.clone();
|
||||||
|
let different_relationship_id = different_relationship.id.clone();
|
||||||
|
|
||||||
relationship1
|
relationship1
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -508,9 +527,9 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||||
|
|
||||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
let result1 = get_relationship_by_id(&relationship1_id, &db).await;
|
||||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
let result2 = get_relationship_by_id(&relationship2_id, &db).await;
|
||||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
let different_result = get_relationship_by_id(&different_relationship_id, &db).await;
|
||||||
|
|
||||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||||
@@ -548,6 +567,8 @@ mod tests {
|
|||||||
shared_source.to_string(),
|
shared_source.to_string(),
|
||||||
"references".to_string(),
|
"references".to_string(),
|
||||||
);
|
);
|
||||||
|
let rel_a_id = rel_a.id.clone();
|
||||||
|
let rel_b_id = rel_b.id.clone();
|
||||||
|
|
||||||
rel_a.store_relationship(&db).await?;
|
rel_a.store_relationship(&db).await?;
|
||||||
rel_b.store_relationship(&db).await?;
|
rel_b.store_relationship(&db).await?;
|
||||||
@@ -555,8 +576,8 @@ mod tests {
|
|||||||
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none());
|
assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none());
|
||||||
assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some());
|
assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -586,6 +607,8 @@ mod tests {
|
|||||||
"other_source".to_string(),
|
"other_source".to_string(),
|
||||||
"contains".to_string(),
|
"contains".to_string(),
|
||||||
);
|
);
|
||||||
|
let safe_relationship_id = safe_relationship.id.clone();
|
||||||
|
let other_relationship_id = other_relationship.id.clone();
|
||||||
|
|
||||||
safe_relationship
|
safe_relationship
|
||||||
.store_relationship(&db)
|
.store_relationship(&db)
|
||||||
@@ -604,8 +627,8 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.expect("delete call should succeed");
|
.expect("delete call should succeed");
|
||||||
|
|
||||||
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
|
let remaining_safe = get_relationship_by_id(&safe_relationship_id, &db).await;
|
||||||
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
|
let remaining_other = get_relationship_by_id(&other_relationship_id, &db).await;
|
||||||
|
|
||||||
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ impl TextChunk {
|
|||||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||||
pub async fn vector_search(
|
pub async fn vector_search(
|
||||||
take: usize,
|
take: usize,
|
||||||
query_embedding: Vec<f32>,
|
query_embedding: &[f32],
|
||||||
db: &SurrealDbClient,
|
db: &SurrealDbClient,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||||
@@ -137,7 +137,7 @@ impl TextChunk {
|
|||||||
|
|
||||||
let mut response = db
|
let mut response = db
|
||||||
.query(&sql)
|
.query(&sql)
|
||||||
.bind(("embedding", query_embedding))
|
.bind(("embedding", query_embedding.to_vec()))
|
||||||
.bind(("user_id", user_id.to_string()))
|
.bind(("user_id", user_id.to_string()))
|
||||||
.await
|
.await
|
||||||
.map_err(AppError::from)?;
|
.map_err(AppError::from)?;
|
||||||
@@ -273,7 +273,7 @@ impl TextChunk {
|
|||||||
let mut processed = 0usize;
|
let mut processed = 0usize;
|
||||||
for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) {
|
for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) {
|
||||||
let inputs: Vec<String> = batch.iter().map(|chunk| chunk.chunk.clone()).collect();
|
let inputs: Vec<String> = batch.iter().map(|chunk| chunk.chunk.clone()).collect();
|
||||||
let embeddings = provider.embed_batch(inputs).await?;
|
let embeddings = provider.embed_batch(&inputs).await?;
|
||||||
if embeddings.len() != batch.len() {
|
if embeddings.len() != batch.len() {
|
||||||
return Err(AppError::internal(format!(
|
return Err(AppError::internal(format!(
|
||||||
"embedding batch returned {} vectors for {} chunks",
|
"embedding batch returned {} vectors for {} chunks",
|
||||||
@@ -720,7 +720,7 @@ mod tests {
|
|||||||
.with_context(|| "redefine index".to_string())?;
|
.with_context(|| "redefine index".to_string())?;
|
||||||
|
|
||||||
let results: Vec<TextChunkSearchResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||||
.await
|
.await
|
||||||
.with_context(|| "vector_search".to_string())?;
|
.with_context(|| "vector_search".to_string())?;
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
@@ -756,7 +756,7 @@ mod tests {
|
|||||||
.with_context(|| "store".to_string())?;
|
.with_context(|| "store".to_string())?;
|
||||||
|
|
||||||
let results: Vec<TextChunkSearchResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "vector_search".to_string())?;
|
.with_context(|| "vector_search".to_string())?;
|
||||||
|
|
||||||
@@ -796,7 +796,7 @@ mod tests {
|
|||||||
.with_context(|| "store chunk2".to_string())?;
|
.with_context(|| "store chunk2".to_string())?;
|
||||||
|
|
||||||
let results: Vec<TextChunkSearchResult> =
|
let results: Vec<TextChunkSearchResult> =
|
||||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
TextChunk::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "vector_search".to_string())?;
|
.with_context(|| "vector_search".to_string())?;
|
||||||
|
|
||||||
@@ -987,7 +987,7 @@ mod tests {
|
|||||||
.await
|
.await
|
||||||
.with_context(|| "delete chunk".to_string())?;
|
.with_context(|| "delete chunk".to_string())?;
|
||||||
|
|
||||||
let results = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
let results = TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||||
|
|
||||||
|
|||||||
@@ -372,17 +372,17 @@ impl EmbeddingProvider {
|
|||||||
///
|
///
|
||||||
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||||
/// Returns an empty `Vec` when `texts` is empty.
|
/// Returns an empty `Vec` when `texts` is empty.
|
||||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||||
match &self.inner {
|
match &self.inner {
|
||||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|text| hashed_embedding(&text, *dimension))
|
.map(|text| hashed_embedding(text, *dimension))
|
||||||
.collect()),
|
.collect()),
|
||||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||||
if texts.is_empty() {
|
if texts.is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
run_fastembed(pool, texts).await
|
run_fastembed(pool, texts.to_vec()).await
|
||||||
}
|
}
|
||||||
EmbeddingInner::OpenAI {
|
EmbeddingInner::OpenAI {
|
||||||
client,
|
client,
|
||||||
@@ -395,7 +395,7 @@ impl EmbeddingProvider {
|
|||||||
|
|
||||||
let request = CreateEmbeddingRequestArgs::default()
|
let request = CreateEmbeddingRequestArgs::default()
|
||||||
.model(model.clone())
|
.model(model.clone())
|
||||||
.input(texts)
|
.input(texts.to_vec())
|
||||||
.dimensions(*dimensions)
|
.dimensions(*dimensions)
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ impl EvaluationCandidate {
|
|||||||
entity_description: Some(entity.entity.description.clone()),
|
entity_description: Some(entity.entity.description.clone()),
|
||||||
entity_category,
|
entity_category,
|
||||||
score: entity.score,
|
score: entity.score,
|
||||||
chunks: entity.chunks,
|
chunks: entity.chunks.as_ref().clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ use crate::{
|
|||||||
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
template_with_headers, TemplateResponse, TemplateResult, ResponseResult,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
utils::pagination::{paginate_items, Pagination},
|
utils::pagination::{paginate_items, paginate_slice, Pagination},
|
||||||
};
|
};
|
||||||
use url::form_urlencoded;
|
use url::form_urlencoded;
|
||||||
|
|
||||||
@@ -196,16 +196,18 @@ pub async fn create_knowledge_entity(
|
|||||||
let source_id = format!("manual::{}", Uuid::new_v4());
|
let source_id = format!("manual::{}", Uuid::new_v4());
|
||||||
let new_entity = KnowledgeEntity::new(
|
let new_entity = KnowledgeEntity::new(
|
||||||
source_id,
|
source_id,
|
||||||
name.clone(),
|
name,
|
||||||
description.clone(),
|
description,
|
||||||
entity_type,
|
entity_type,
|
||||||
None,
|
None,
|
||||||
user.id.clone(),
|
user.id.clone(),
|
||||||
);
|
);
|
||||||
|
let new_entity_id = new_entity.id.clone();
|
||||||
|
|
||||||
KnowledgeEntity::store_with_embedding(new_entity.clone(), embedding, &state.db).await?;
|
KnowledgeEntity::store_with_embedding(new_entity, embedding, &state.db).await?;
|
||||||
|
|
||||||
let relationship_type = relationship_type_or_default(form.relationship_type.as_deref());
|
let relationship_type = relationship_type_or_default(form.relationship_type.as_deref());
|
||||||
|
let user_id = user.id.clone();
|
||||||
|
|
||||||
debug!("form: {:?}", form);
|
debug!("form: {:?}", form);
|
||||||
if !form.relationship_ids.is_empty() {
|
if !form.relationship_ids.is_empty() {
|
||||||
@@ -217,7 +219,7 @@ pub async fn create_knowledge_entity(
|
|||||||
let mut unique_ids: HashSet<String> = HashSet::new();
|
let mut unique_ids: HashSet<String> = HashSet::new();
|
||||||
|
|
||||||
for target_id in form.relationship_ids {
|
for target_id in form.relationship_ids {
|
||||||
if target_id == new_entity.id {
|
if target_id == new_entity_id {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if !valid_ids.contains(&target_id) {
|
if !valid_ids.contains(&target_id) {
|
||||||
@@ -228,10 +230,10 @@ pub async fn create_knowledge_entity(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let relationship = KnowledgeRelationship::new(
|
let relationship = KnowledgeRelationship::new(
|
||||||
new_entity.id.clone(),
|
new_entity_id.clone(),
|
||||||
target_id,
|
target_id,
|
||||||
user.id.clone(),
|
user_id.clone(),
|
||||||
format!("manual::{}", new_entity.id),
|
format!("manual::{new_entity_id}"),
|
||||||
relationship_type.clone(),
|
relationship_type.clone(),
|
||||||
);
|
);
|
||||||
relationship.store_relationship(&state.db).await?;
|
relationship.store_relationship(&state.db).await?;
|
||||||
@@ -385,7 +387,7 @@ async fn suggest_related_entities(
|
|||||||
let suggestion_min_rrf_score = 1.0 / (tuning.chunk_rrf_k + 1.0);
|
let suggestion_min_rrf_score = 1.0 / (tuning.chunk_rrf_k + 1.0);
|
||||||
|
|
||||||
let (vector_rows, fts_rows) = tokio::try_join!(
|
let (vector_rows, fts_rows) = tokio::try_join!(
|
||||||
KnowledgeEntity::vector_search(take, embedding, db, user_id),
|
KnowledgeEntity::vector_search(take, &embedding, db, user_id),
|
||||||
async {
|
async {
|
||||||
if fts_enabled {
|
if fts_enabled {
|
||||||
KnowledgeEntity::fts_search(take, &fts_query, db, user_id).await
|
KnowledgeEntity::fts_search(take, &fts_query, db, user_id).await
|
||||||
@@ -480,10 +482,13 @@ fn build_relationship_options(
|
|||||||
options
|
options
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_relationship_table_data(
|
fn build_relationship_rows(
|
||||||
entities: Vec<KnowledgeEntity>,
|
|
||||||
relationships: Vec<KnowledgeRelationship>,
|
relationships: Vec<KnowledgeRelationship>,
|
||||||
) -> RelationshipTableData {
|
) -> (
|
||||||
|
Vec<RelationshipTableRow>,
|
||||||
|
Vec<String>,
|
||||||
|
String,
|
||||||
|
) {
|
||||||
let relationship_type_options = collect_relationship_type_options(&relationships);
|
let relationship_type_options = collect_relationship_type_options(&relationships);
|
||||||
let mut frequency: HashMap<String, usize> = HashMap::new();
|
let mut frequency: HashMap<String, usize> = HashMap::new();
|
||||||
let relationships = relationships
|
let relationships = relationships
|
||||||
@@ -503,7 +508,25 @@ fn build_relationship_table_data(
|
|||||||
.collect();
|
.collect();
|
||||||
let default_relationship_type = frequency
|
let default_relationship_type = frequency
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.max_by_key(|(_, count)| *count).map_or_else(|| DEFAULT_RELATIONSHIP_TYPE.to_string(), |(label, _)| label);
|
.max_by_key(|(_, count)| *count)
|
||||||
|
.map_or_else(
|
||||||
|
|| DEFAULT_RELATIONSHIP_TYPE.to_string(),
|
||||||
|
|(label, _)| label,
|
||||||
|
);
|
||||||
|
|
||||||
|
(
|
||||||
|
relationships,
|
||||||
|
relationship_type_options,
|
||||||
|
default_relationship_type,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_relationship_table_data(
|
||||||
|
entities: Vec<KnowledgeEntity>,
|
||||||
|
relationships: Vec<KnowledgeRelationship>,
|
||||||
|
) -> RelationshipTableData {
|
||||||
|
let (relationships, relationship_type_options, default_relationship_type) =
|
||||||
|
build_relationship_rows(relationships);
|
||||||
|
|
||||||
RelationshipTableData {
|
RelationshipTableData {
|
||||||
entities,
|
entities,
|
||||||
@@ -532,7 +555,7 @@ async fn build_knowledge_base_data(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (visible_entities, pagination) =
|
let (visible_entities, pagination) =
|
||||||
paginate_items(entities.clone(), params.page, KNOWLEDGE_ENTITIES_PER_PAGE);
|
paginate_slice(&entities, params.page, KNOWLEDGE_ENTITIES_PER_PAGE);
|
||||||
|
|
||||||
let page_query = {
|
let page_query = {
|
||||||
let mut serializer = form_urlencoded::Serializer::new(String::new());
|
let mut serializer = form_urlencoded::Serializer::new(String::new());
|
||||||
@@ -551,17 +574,15 @@ async fn build_knowledge_base_data(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
|
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
|
||||||
let entity_id_set: HashSet<String> = entities.iter().map(|e| e.id.clone()).collect();
|
let entity_id_set: HashSet<&str> = entities.iter().map(|e| e.id.as_str()).collect();
|
||||||
let filtered_relationships: Vec<KnowledgeRelationship> = relationships
|
let filtered_relationships: Vec<KnowledgeRelationship> = relationships
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|rel| entity_id_set.contains(&rel.in_) && entity_id_set.contains(&rel.out))
|
.filter(|rel| {
|
||||||
|
entity_id_set.contains(rel.in_.as_str()) && entity_id_set.contains(rel.out.as_str())
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let RelationshipTableData {
|
let (relationships, relationship_type_options, default_relationship_type) =
|
||||||
entities: _,
|
build_relationship_rows(filtered_relationships);
|
||||||
relationships,
|
|
||||||
relationship_type_options,
|
|
||||||
default_relationship_type,
|
|
||||||
} = build_relationship_table_data(entities.clone(), filtered_relationships);
|
|
||||||
|
|
||||||
Ok(KnowledgeBaseData {
|
Ok(KnowledgeBaseData {
|
||||||
entities,
|
entities,
|
||||||
|
|||||||
@@ -57,6 +57,47 @@ impl Pagination {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a cloned page slice and pagination metadata without consuming the source list.
|
||||||
|
pub fn paginate_slice<T: Clone>(
|
||||||
|
items: &[T],
|
||||||
|
requested_page: Option<usize>,
|
||||||
|
per_page: usize,
|
||||||
|
) -> (Vec<T>, Pagination) {
|
||||||
|
let per_page = per_page.max(1);
|
||||||
|
let total_items = items.len();
|
||||||
|
let total_pages = if total_items == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
total_items
|
||||||
|
.saturating_sub(1)
|
||||||
|
.checked_div(per_page)
|
||||||
|
.unwrap_or(0)
|
||||||
|
.saturating_add(1)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut current_page = requested_page.unwrap_or(1);
|
||||||
|
if current_page == 0 {
|
||||||
|
current_page = 1;
|
||||||
|
}
|
||||||
|
if total_pages > 0 {
|
||||||
|
current_page = current_page.min(total_pages);
|
||||||
|
} else {
|
||||||
|
current_page = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let offset = if total_pages == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
per_page.saturating_mul(current_page.saturating_sub(1))
|
||||||
|
};
|
||||||
|
|
||||||
|
let page_items: Vec<T> = items.iter().skip(offset).take(per_page).cloned().collect();
|
||||||
|
let page_len = page_items.len();
|
||||||
|
let pagination = Pagination::new(current_page, per_page, total_items, total_pages, page_len);
|
||||||
|
|
||||||
|
(page_items, pagination)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the items for the requested page along with pagination metadata.
|
/// Returns the items for the requested page along with pagination metadata.
|
||||||
pub fn paginate_items<T>(
|
pub fn paginate_items<T>(
|
||||||
items: Vec<T>,
|
items: Vec<T>,
|
||||||
@@ -96,7 +137,7 @@ pub fn paginate_items<T>(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::paginate_items;
|
use super::{paginate_items, paginate_slice};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn paginates_basic_case() {
|
fn paginates_basic_case() {
|
||||||
@@ -128,6 +169,16 @@ mod tests {
|
|||||||
assert_eq!(meta.end_index, 0);
|
assert_eq!(meta.end_index, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn paginate_slice_clones_only_page_items() {
|
||||||
|
let items: Vec<_> = (1..=25).collect();
|
||||||
|
let (page, meta) = paginate_slice(&items, Some(2), 10);
|
||||||
|
|
||||||
|
assert_eq!(page, vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20]);
|
||||||
|
assert_eq!(items.len(), 25);
|
||||||
|
assert_eq!(meta.current_page, 2);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn clamps_page_to_bounds() {
|
fn clamps_page_to_bounds() {
|
||||||
let items: Vec<_> = (1..=5).collect();
|
let items: Vec<_> = (1..=5).collect();
|
||||||
|
|||||||
@@ -261,23 +261,24 @@ impl PipelineServices for DefaultPipelineServices {
|
|||||||
|
|
||||||
// Embed all chunks of this document in one batch: a single lock acquisition and one
|
// Embed all chunks of this document in one batch: a single lock acquisition and one
|
||||||
// blocking hop, letting the backend batch the inference internally.
|
// blocking hop, letting the backend batch the inference internally.
|
||||||
|
let batch_len = chunk_candidates.len();
|
||||||
let embeddings = self
|
let embeddings = self
|
||||||
.embedding_provider
|
.embedding_provider
|
||||||
.embed_batch(chunk_candidates.clone())
|
.embed_batch(&chunk_candidates)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
AppError::InternalError(format!("FastEmbed embedding for chunks failed: {e}"))
|
AppError::InternalError(format!("FastEmbed embedding for chunks failed: {e}"))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if embeddings.len() != chunk_candidates.len() {
|
if embeddings.len() != batch_len {
|
||||||
return Err(AppError::InternalError(format!(
|
return Err(AppError::InternalError(format!(
|
||||||
"embedding batch returned {} vectors for {} chunks",
|
"embedding batch returned {} vectors for {} chunks",
|
||||||
embeddings.len(),
|
embeddings.len(),
|
||||||
chunk_candidates.len()
|
batch_len
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut chunks = Vec::with_capacity(chunk_candidates.len());
|
let mut chunks = Vec::with_capacity(batch_len);
|
||||||
for (chunk_text, embedding) in chunk_candidates.into_iter().zip(embeddings) {
|
for (chunk_text, embedding) in chunk_candidates.into_iter().zip(embeddings) {
|
||||||
let chunk_struct = TextChunk::new(
|
let chunk_struct = TextChunk::new(
|
||||||
content.id().to_string(),
|
content.id().to_string(),
|
||||||
|
|||||||
@@ -91,10 +91,10 @@ impl MockServices {
|
|||||||
similar_entities: vec![RetrievedEntity {
|
similar_entities: vec![RetrievedEntity {
|
||||||
entity: retrieved_entity,
|
entity: retrieved_entity,
|
||||||
score: 0.8,
|
score: 0.8,
|
||||||
chunks: vec![RetrievedChunk {
|
chunks: std::sync::Arc::new(vec![RetrievedChunk {
|
||||||
chunk: retrieved_chunk,
|
chunk: retrieved_chunk,
|
||||||
score: 0.7,
|
score: 0.7,
|
||||||
}],
|
}]),
|
||||||
}],
|
}],
|
||||||
analysis,
|
analysis,
|
||||||
chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM],
|
chunk_embedding: vec![0.3; TEST_EMBEDDING_DIM],
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ pub mod query;
|
|||||||
pub mod reranking;
|
pub mod reranking;
|
||||||
pub mod scoring;
|
pub mod scoring;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
error::AppError,
|
error::AppError,
|
||||||
storage::{
|
storage::{
|
||||||
@@ -52,7 +54,7 @@ pub struct RetrievedChunk {
|
|||||||
pub struct RetrievedEntity {
|
pub struct RetrievedEntity {
|
||||||
pub entity: KnowledgeEntity,
|
pub entity: KnowledgeEntity,
|
||||||
pub score: f32,
|
pub score: f32,
|
||||||
pub chunks: Vec<RetrievedChunk>,
|
pub chunks: Arc<Vec<RetrievedChunk>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
|
/// Run chunk-first hybrid retrieval for `input_text`, optionally resolving owning entities.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use common::{
|
|||||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||||
};
|
};
|
||||||
use fastembed::RerankResult;
|
use fastembed::RerankResult;
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt::Write, sync::Arc};
|
||||||
use tracing::{debug, instrument, warn};
|
use tracing::{debug, instrument, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -106,7 +106,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
|||||||
#[instrument(level = "trace", skip_all)]
|
#[instrument(level = "trace", skip_all)]
|
||||||
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
pub async fn search_chunks(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||||
debug!("Collecting chunk candidates via vector and FTS search");
|
debug!("Collecting chunk candidates via vector and FTS search");
|
||||||
let embedding = ctx.ensure_embedding().map_err(|e| *e)?.clone();
|
let embedding = ctx.ensure_embedding().map_err(|e| *e)?;
|
||||||
let tuning = &ctx.config.tuning;
|
let tuning = &ctx.config.tuning;
|
||||||
let fts_take = tuning.chunk_fts_take;
|
let fts_take = tuning.chunk_fts_take;
|
||||||
let (fts_query, fts_token_count) = normalize_fts_terms(&ctx.input_text);
|
let (fts_query, fts_token_count) = normalize_fts_terms(&ctx.input_text);
|
||||||
@@ -233,12 +233,16 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr
|
|||||||
let mut best_score: HashMap<String, f32> = HashMap::new();
|
let mut best_score: HashMap<String, f32> = HashMap::new();
|
||||||
|
|
||||||
for scored in &ctx.chunk_values {
|
for scored in &ctx.chunk_values {
|
||||||
let source = scored.item.source_id.clone();
|
let source_id = &scored.item.source_id;
|
||||||
let attached = chunks_by_source.entry(source.clone()).or_default();
|
let is_new_source = !chunks_by_source.contains_key(source_id);
|
||||||
if attached.is_empty() {
|
if is_new_source {
|
||||||
source_order.push(source.clone());
|
source_order.push(source_id.clone());
|
||||||
best_score.insert(source.clone(), scored.fused);
|
best_score.insert(source_id.clone(), scored.fused);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let attached = chunks_by_source
|
||||||
|
.entry(source_id.clone())
|
||||||
|
.or_default();
|
||||||
if attached.len() < max_chunks {
|
if attached.len() < max_chunks {
|
||||||
attached.push(RetrievedChunk {
|
attached.push(RetrievedChunk {
|
||||||
chunk: scored.item.clone(),
|
chunk: scored.item.clone(),
|
||||||
@@ -247,6 +251,11 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let chunks_by_source: HashMap<String, Arc<Vec<RetrievedChunk>>> = chunks_by_source
|
||||||
|
.into_iter()
|
||||||
|
.map(|(source, chunks)| (source, Arc::new(chunks)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
let entities =
|
let entities =
|
||||||
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
KnowledgeEntity::find_by_source_ids(ctx.db_client, &source_order, &ctx.user_id).await?;
|
||||||
|
|
||||||
@@ -264,12 +273,15 @@ pub async fn resolve_entities(ctx: &mut PipelineContext<'_>) -> Result<(), AppEr
|
|||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
let score = best_score.get(source).copied().unwrap_or(0.0);
|
let score = best_score.get(source).copied().unwrap_or(0.0);
|
||||||
let chunks = chunks_by_source.get(source).cloned().unwrap_or_default();
|
let chunks = chunks_by_source
|
||||||
|
.get(source)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| Arc::new(Vec::new()));
|
||||||
for entity in entities {
|
for entity in entities {
|
||||||
results.push(RetrievedEntity {
|
results.push(RetrievedEntity {
|
||||||
entity,
|
entity,
|
||||||
score,
|
score,
|
||||||
chunks: chunks.clone(),
|
chunks: Arc::clone(&chunks),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -328,17 +340,26 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
fn build_chunk_rerank_documents(chunks: &[Scored<TextChunk>], max_chunks: usize) -> Vec<String> {
|
||||||
chunks
|
let take = chunks.len().min(max_chunks);
|
||||||
.iter()
|
let mut documents = Vec::with_capacity(take);
|
||||||
.take(max_chunks)
|
let mut buffer = String::with_capacity(512);
|
||||||
.map(|chunk| {
|
|
||||||
format!(
|
for chunk in chunks.iter().take(max_chunks) {
|
||||||
"Source: {}\nChunk:\n{}",
|
buffer.clear();
|
||||||
chunk.item.source_id,
|
let _ = write!(
|
||||||
chunk.item.chunk.trim()
|
buffer,
|
||||||
)
|
"Source: {}\nChunk:\n{}",
|
||||||
})
|
chunk.item.source_id,
|
||||||
.collect()
|
chunk.item.chunk.trim()
|
||||||
|
);
|
||||||
|
let next_capacity = buffer.capacity().max(512);
|
||||||
|
documents.push(std::mem::replace(
|
||||||
|
&mut buffer,
|
||||||
|
String::with_capacity(next_capacity),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
documents
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_chunk_rerank_results(
|
fn apply_chunk_rerank_results(
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
use std::{cmp::Ordering, collections::HashMap};
|
use std::{
|
||||||
|
cmp::Ordering,
|
||||||
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
};
|
||||||
|
|
||||||
use common::storage::types::StoredObject;
|
use common::storage::types::StoredObject;
|
||||||
|
|
||||||
@@ -119,7 +122,7 @@ pub fn reciprocal_rank_fusion<T>(
|
|||||||
config: RrfConfig,
|
config: RrfConfig,
|
||||||
) -> Vec<Scored<T>>
|
) -> Vec<Scored<T>>
|
||||||
where
|
where
|
||||||
T: StoredObject + Clone,
|
T: StoredObject,
|
||||||
{
|
{
|
||||||
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
let mut merged: HashMap<String, Scored<T>> = HashMap::new();
|
||||||
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
let k = if config.k <= 0.0 { 60.0 } else { config.k };
|
||||||
@@ -146,19 +149,30 @@ where
|
|||||||
|
|
||||||
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
for (rank, candidate) in vector_ranked.into_iter().enumerate() {
|
||||||
let id = candidate.item.id().to_owned();
|
let id = candidate.item.id().to_owned();
|
||||||
let entry = merged
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
.entry(id.clone())
|
let contribution = vector_weight / (k + rank_f32 + 1.0);
|
||||||
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
|
||||||
|
|
||||||
if let Some(score) = candidate.scores.vector {
|
match merged.entry(id) {
|
||||||
let existing = entry.scores.vector.unwrap_or(f32::MIN);
|
Entry::Occupied(mut occupied) => {
|
||||||
if score > existing {
|
let entry = occupied.get_mut();
|
||||||
entry.scores.vector = Some(score);
|
if let Some(score) = candidate.scores.vector {
|
||||||
|
let existing = entry.scores.vector.unwrap_or(f32::MIN);
|
||||||
|
if score > existing {
|
||||||
|
entry.scores.vector = Some(score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entry.item = candidate.item;
|
||||||
|
entry.fused += contribution;
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant) => {
|
||||||
|
let mut scored = Scored::new(candidate.item);
|
||||||
|
if let Some(score) = candidate.scores.vector {
|
||||||
|
scored.scores.vector = Some(score);
|
||||||
|
}
|
||||||
|
scored.fused = contribution;
|
||||||
|
vacant.insert(scored);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
entry.item = candidate.item;
|
|
||||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
|
||||||
entry.fused += vector_weight / (k + rank_f32 + 1.0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,19 +188,30 @@ where
|
|||||||
|
|
||||||
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
for (rank, candidate) in fts_ranked.into_iter().enumerate() {
|
||||||
let id = candidate.item.id().to_owned();
|
let id = candidate.item.id().to_owned();
|
||||||
let entry = merged
|
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
||||||
.entry(id.clone())
|
let contribution = fts_weight / (k + rank_f32 + 1.0);
|
||||||
.or_insert_with(|| Scored::new(candidate.item.clone()));
|
|
||||||
|
|
||||||
if let Some(score) = candidate.scores.fts {
|
match merged.entry(id) {
|
||||||
let existing = entry.scores.fts.unwrap_or(f32::MIN);
|
Entry::Occupied(mut occupied) => {
|
||||||
if score > existing {
|
let entry = occupied.get_mut();
|
||||||
entry.scores.fts = Some(score);
|
if let Some(score) = candidate.scores.fts {
|
||||||
|
let existing = entry.scores.fts.unwrap_or(f32::MIN);
|
||||||
|
if score > existing {
|
||||||
|
entry.scores.fts = Some(score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
entry.item = candidate.item;
|
||||||
|
entry.fused += contribution;
|
||||||
|
}
|
||||||
|
Entry::Vacant(vacant) => {
|
||||||
|
let mut scored = Scored::new(candidate.item);
|
||||||
|
if let Some(score) = candidate.scores.fts {
|
||||||
|
scored.scores.fts = Some(score);
|
||||||
|
}
|
||||||
|
scored.fused = contribution;
|
||||||
|
vacant.insert(scored);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
entry.item = candidate.item;
|
|
||||||
let rank_f32: f32 = u16::try_from(rank).map_or(f32::MAX, f32::from);
|
|
||||||
entry.fused += fts_weight / (k + rank_f32 + 1.0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user