mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-26 03:46:24 +02:00
fix: replaced several instances of cloning, reduced allocations
This commit is contained in:
@@ -258,7 +258,7 @@ impl KnowledgeEntity {
|
||||
/// Vector search over knowledge entities using the embedding table, fetching full entity rows and scores.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
query_embedding: &[f32],
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntitySearchResult>, AppError> {
|
||||
@@ -286,7 +286,7 @@ impl KnowledgeEntity {
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
@@ -408,7 +408,7 @@ impl KnowledgeEntity {
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let embeddings = provider.embed_batch(inputs).await?;
|
||||
let embeddings = provider.embed_batch(&inputs).await?;
|
||||
if embeddings.len() != batch.len() {
|
||||
return Err(AppError::internal(format!(
|
||||
"embedding batch returned {} vectors for {} entities",
|
||||
@@ -817,7 +817,7 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
|
||||
let results = KnowledgeEntity::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
let results = KnowledgeEntity::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.expect("vector search");
|
||||
assert!(results.is_empty());
|
||||
@@ -878,7 +878,7 @@ mod tests {
|
||||
.with_context(|| "fetch embedding".to_string())?;
|
||||
assert!(fetched_emb.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
@@ -965,7 +965,7 @@ mod tests {
|
||||
.with_context(|| "get embedding e2".to_string())?
|
||||
.is_some());
|
||||
|
||||
let results = KnowledgeEntity::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
let results = KnowledgeEntity::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "vector search".to_string())?;
|
||||
|
||||
@@ -1030,7 +1030,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "delete entity".to_string())?;
|
||||
|
||||
let results = KnowledgeEntity::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
let results = KnowledgeEntity::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||
|
||||
|
||||
@@ -42,12 +42,24 @@ impl KnowledgeRelationship {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
pub async fn store_relationship(self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
||||
.await?;
|
||||
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
||||
.await?;
|
||||
|
||||
let Self {
|
||||
id,
|
||||
in_,
|
||||
out,
|
||||
metadata:
|
||||
RelationshipMetadata {
|
||||
user_id,
|
||||
source_id,
|
||||
relationship_type,
|
||||
},
|
||||
} = self;
|
||||
|
||||
db_client
|
||||
.client
|
||||
.query(
|
||||
@@ -62,12 +74,12 @@ impl KnowledgeRelationship {
|
||||
metadata.relationship_type = $relationship_type;
|
||||
COMMIT TRANSACTION;"#,
|
||||
)
|
||||
.bind(("rel_id", self.id.clone()))
|
||||
.bind(("in_id", self.in_.clone()))
|
||||
.bind(("out_id", self.out.clone()))
|
||||
.bind(("user_id", self.metadata.user_id.clone()))
|
||||
.bind(("source_id", self.metadata.source_id.clone()))
|
||||
.bind(("relationship_type", self.metadata.relationship_type.clone()))
|
||||
.bind(("rel_id", id))
|
||||
.bind(("in_id", in_))
|
||||
.bind(("out_id", out))
|
||||
.bind(("user_id", user_id))
|
||||
.bind(("source_id", source_id))
|
||||
.bind(("relationship_type", relationship_type))
|
||||
.await
|
||||
.map_err(AppError::from)?
|
||||
.check()
|
||||
@@ -230,13 +242,14 @@ mod tests {
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
let relationship_id = relationship.id.clone();
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
.await
|
||||
.with_context(|| "Failed to store relationship".to_string())?;
|
||||
|
||||
let persisted = get_relationship_by_id(&relationship.id, &db)
|
||||
let persisted = get_relationship_by_id(&relationship_id, &db)
|
||||
.await
|
||||
.expect("Relationship should be retrievable by id");
|
||||
assert_eq!(persisted.in_, entity1_id);
|
||||
@@ -296,6 +309,7 @@ mod tests {
|
||||
"source123'; DELETE FROM relates_to; --".to_string(),
|
||||
"references'; UPDATE user SET admin = true; --".to_string(),
|
||||
);
|
||||
let relationship_id = relationship.id.clone();
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
@@ -305,7 +319,7 @@ mod tests {
|
||||
let mut res = db
|
||||
.client
|
||||
.query("SELECT * FROM relates_to WHERE id = type::thing('relates_to', $id)")
|
||||
.bind(("id", relationship.id.clone()))
|
||||
.bind(("id", relationship_id))
|
||||
.await
|
||||
.expect("query relationship by id failed");
|
||||
let rows: Vec<KnowledgeRelationship> = res.take(0).expect("take rows");
|
||||
@@ -338,6 +352,7 @@ mod tests {
|
||||
source_id.clone(),
|
||||
relationship_type,
|
||||
);
|
||||
let relationship_id = relationship.id.clone();
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
@@ -357,7 +372,7 @@ mod tests {
|
||||
"Relationship should exist before deletion"
|
||||
);
|
||||
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship.id, user_id, &db)
|
||||
KnowledgeRelationship::delete_relationship_by_id(&relationship_id, user_id, &db)
|
||||
.await
|
||||
.with_context(|| "Failed to delete relationship by ID".to_string())?;
|
||||
|
||||
@@ -391,6 +406,7 @@ mod tests {
|
||||
source_id,
|
||||
"references".to_string(),
|
||||
);
|
||||
let relationship_id = relationship.id.clone();
|
||||
|
||||
relationship
|
||||
.store_relationship(&db)
|
||||
@@ -409,7 +425,7 @@ mod tests {
|
||||
);
|
||||
|
||||
let result = KnowledgeRelationship::delete_relationship_by_id(
|
||||
&relationship.id,
|
||||
&relationship_id,
|
||||
"different-user",
|
||||
&db,
|
||||
)
|
||||
@@ -472,6 +488,9 @@ mod tests {
|
||||
different_source_id.clone(),
|
||||
"mentions".to_string(),
|
||||
);
|
||||
let relationship1_id = relationship1.id.clone();
|
||||
let relationship2_id = relationship2.id.clone();
|
||||
let different_relationship_id = different_relationship.id.clone();
|
||||
|
||||
relationship1
|
||||
.store_relationship(&db)
|
||||
@@ -508,9 +527,9 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "Failed to delete relationships by source_id".to_string())?;
|
||||
|
||||
let result1 = get_relationship_by_id(&relationship1.id, &db).await;
|
||||
let result2 = get_relationship_by_id(&relationship2.id, &db).await;
|
||||
let different_result = get_relationship_by_id(&different_relationship.id, &db).await;
|
||||
let result1 = get_relationship_by_id(&relationship1_id, &db).await;
|
||||
let result2 = get_relationship_by_id(&relationship2_id, &db).await;
|
||||
let different_result = get_relationship_by_id(&different_relationship_id, &db).await;
|
||||
|
||||
assert!(result1.is_none(), "Relationship 1 should be deleted");
|
||||
assert!(result2.is_none(), "Relationship 2 should be deleted");
|
||||
@@ -548,6 +567,8 @@ mod tests {
|
||||
shared_source.to_string(),
|
||||
"references".to_string(),
|
||||
);
|
||||
let rel_a_id = rel_a.id.clone();
|
||||
let rel_b_id = rel_b.id.clone();
|
||||
|
||||
rel_a.store_relationship(&db).await?;
|
||||
rel_b.store_relationship(&db).await?;
|
||||
@@ -555,8 +576,8 @@ mod tests {
|
||||
KnowledgeRelationship::delete_relationships_by_source_id(shared_source, user_a, &db)
|
||||
.await?;
|
||||
|
||||
assert!(get_relationship_by_id(&rel_a.id, &db).await.is_none());
|
||||
assert!(get_relationship_by_id(&rel_b.id, &db).await.is_some());
|
||||
assert!(get_relationship_by_id(&rel_a_id, &db).await.is_none());
|
||||
assert!(get_relationship_by_id(&rel_b_id, &db).await.is_some());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -586,6 +607,8 @@ mod tests {
|
||||
"other_source".to_string(),
|
||||
"contains".to_string(),
|
||||
);
|
||||
let safe_relationship_id = safe_relationship.id.clone();
|
||||
let other_relationship_id = other_relationship.id.clone();
|
||||
|
||||
safe_relationship
|
||||
.store_relationship(&db)
|
||||
@@ -604,8 +627,8 @@ mod tests {
|
||||
.await
|
||||
.expect("delete call should succeed");
|
||||
|
||||
let remaining_safe = get_relationship_by_id(&safe_relationship.id, &db).await;
|
||||
let remaining_other = get_relationship_by_id(&other_relationship.id, &db).await;
|
||||
let remaining_safe = get_relationship_by_id(&safe_relationship_id, &db).await;
|
||||
let remaining_other = get_relationship_by_id(&other_relationship_id, &db).await;
|
||||
|
||||
assert!(remaining_safe.is_some(), "Safe relationship should remain");
|
||||
assert!(
|
||||
|
||||
@@ -107,7 +107,7 @@ impl TextChunk {
|
||||
/// Vector search over text chunks using the embedding table, fetching full chunk rows and embeddings.
|
||||
pub async fn vector_search(
|
||||
take: usize,
|
||||
query_embedding: Vec<f32>,
|
||||
query_embedding: &[f32],
|
||||
db: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<TextChunkSearchResult>, AppError> {
|
||||
@@ -137,7 +137,7 @@ impl TextChunk {
|
||||
|
||||
let mut response = db
|
||||
.query(&sql)
|
||||
.bind(("embedding", query_embedding))
|
||||
.bind(("embedding", query_embedding.to_vec()))
|
||||
.bind(("user_id", user_id.to_string()))
|
||||
.await
|
||||
.map_err(AppError::from)?;
|
||||
@@ -273,7 +273,7 @@ impl TextChunk {
|
||||
let mut processed = 0usize;
|
||||
for batch in all_chunks.chunks(RE_EMBED_BATCH_SIZE) {
|
||||
let inputs: Vec<String> = batch.iter().map(|chunk| chunk.chunk.clone()).collect();
|
||||
let embeddings = provider.embed_batch(inputs).await?;
|
||||
let embeddings = provider.embed_batch(&inputs).await?;
|
||||
if embeddings.len() != batch.len() {
|
||||
return Err(AppError::internal(format!(
|
||||
"embedding batch returned {} vectors for {} chunks",
|
||||
@@ -720,7 +720,7 @@ mod tests {
|
||||
.with_context(|| "redefine index".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(5, vec![0.1, 0.2, 0.3], &db, "user")
|
||||
TextChunk::vector_search(5, &[0.1, 0.2, 0.3], &db, "user")
|
||||
.await
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
assert!(results.is_empty());
|
||||
@@ -756,7 +756,7 @@ mod tests {
|
||||
.with_context(|| "store".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
@@ -796,7 +796,7 @@ mod tests {
|
||||
.with_context(|| "store chunk2".to_string())?;
|
||||
|
||||
let results: Vec<TextChunkSearchResult> =
|
||||
TextChunk::vector_search(2, vec![0.0, 1.0, 0.0], &db, &user_id)
|
||||
TextChunk::vector_search(2, &[0.0, 1.0, 0.0], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "vector_search".to_string())?;
|
||||
|
||||
@@ -987,7 +987,7 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "delete chunk".to_string())?;
|
||||
|
||||
let results = TextChunk::vector_search(3, vec![0.1, 0.2, 0.3], &db, &user_id)
|
||||
let results = TextChunk::vector_search(3, &[0.1, 0.2, 0.3], &db, &user_id)
|
||||
.await
|
||||
.with_context(|| "search should succeed even with orphans".to_string())?;
|
||||
|
||||
|
||||
@@ -372,17 +372,17 @@ impl EmbeddingProvider {
|
||||
///
|
||||
/// Returns [`EmbeddingError`] if the backend API call fails or returns no embedding data.
|
||||
/// Returns an empty `Vec` when `texts` is empty.
|
||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
pub async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match &self.inner {
|
||||
EmbeddingInner::Hashed { dimension } => Ok(texts
|
||||
.into_iter()
|
||||
.map(|text| hashed_embedding(&text, *dimension))
|
||||
.iter()
|
||||
.map(|text| hashed_embedding(text, *dimension))
|
||||
.collect()),
|
||||
EmbeddingInner::FastEmbed { pool, .. } => {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
run_fastembed(pool, texts).await
|
||||
run_fastembed(pool, texts.to_vec()).await
|
||||
}
|
||||
EmbeddingInner::OpenAI {
|
||||
client,
|
||||
@@ -395,7 +395,7 @@ impl EmbeddingProvider {
|
||||
|
||||
let request = CreateEmbeddingRequestArgs::default()
|
||||
.model(model.clone())
|
||||
.input(texts)
|
||||
.input(texts.to_vec())
|
||||
.dimensions(*dimensions)
|
||||
.build()?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user