mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-18 12:39:38 +02:00
chore: ingestion-pipeline refactor, sort technical debt, rustfmt
This commit is contained in:
@@ -28,7 +28,11 @@ pub fn hnsw_index_overwrite_sql(index_name: &str, table: &str, dimension: usize)
|
||||
|
||||
/// Recreates an HNSW index inside a transaction (for tests and dimension migrations).
|
||||
#[must_use]
|
||||
pub fn hnsw_index_redefine_transaction_sql(index_name: &str, table: &str, dimension: usize) -> String {
|
||||
pub fn hnsw_index_redefine_transaction_sql(
|
||||
index_name: &str,
|
||||
table: &str,
|
||||
dimension: usize,
|
||||
) -> String {
|
||||
format!(
|
||||
"BEGIN TRANSACTION;
|
||||
REMOVE INDEX IF EXISTS {index_name} ON TABLE {table};
|
||||
@@ -204,9 +208,7 @@ pub async fn ensure_runtime(
|
||||
///
|
||||
/// Returns `AppError::InternalError` if any index rebuild operation fails.
|
||||
pub async fn rebuild(db: &SurrealDbClient) -> Result<(), AppError> {
|
||||
rebuild_inner(db)
|
||||
.await
|
||||
.map_err(AppError::internal)
|
||||
rebuild_inner(db).await.map_err(AppError::internal)
|
||||
}
|
||||
|
||||
async fn ensure_runtime_inner(db: &SurrealDbClient, embedding_dimension: usize) -> Result<()> {
|
||||
@@ -297,8 +299,8 @@ async fn get_index_status(db: &SurrealDbClient, index_name: &str, table: &str) -
|
||||
return Ok("unknown".to_string());
|
||||
};
|
||||
|
||||
let parsed: IndexInfoForIndex = serde_json::from_value(info)
|
||||
.context("deserializing INFO FOR INDEX response")?;
|
||||
let parsed: IndexInfoForIndex =
|
||||
serde_json::from_value(info).context("deserializing INFO FOR INDEX response")?;
|
||||
|
||||
Ok(parsed.building_status())
|
||||
}
|
||||
|
||||
@@ -518,8 +518,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_sidebar_conversation_deserializes_plain_string_id() {
|
||||
let item: SidebarConversation = serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
|
||||
.expect("valid sidebar conversation json");
|
||||
let item: SidebarConversation =
|
||||
serde_json::from_str(r#"{"id":"conv-plain","title":"My chat"}"#)
|
||||
.expect("valid sidebar conversation json");
|
||||
assert_eq!(item.id, "conv-plain");
|
||||
assert_eq!(item.title, "My chat");
|
||||
}
|
||||
@@ -568,8 +569,7 @@ mod tests {
|
||||
))
|
||||
.await?;
|
||||
|
||||
let owner_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||
let owner_messages = fetch_messages_for_owner(&db, &conversation_id, owner).await?;
|
||||
assert_eq!(owner_messages.len(), 1);
|
||||
assert_eq!(
|
||||
owner_messages
|
||||
@@ -579,8 +579,7 @@ mod tests {
|
||||
"secret message"
|
||||
);
|
||||
|
||||
let intruder_messages =
|
||||
fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||
let intruder_messages = fetch_messages_for_owner(&db, &conversation_id, intruder).await?;
|
||||
assert!(
|
||||
intruder_messages.is_empty(),
|
||||
"SQL owner filter must not return messages for a non-owner user_id"
|
||||
|
||||
@@ -205,13 +205,9 @@ impl FileInfo {
|
||||
let now = Utc::now();
|
||||
let storage_prefix = format!("{user_id}/{uuid}");
|
||||
|
||||
let path = Self::persist_bytes_with_storage(
|
||||
&storage_prefix,
|
||||
&sanitized_file_name,
|
||||
bytes,
|
||||
storage,
|
||||
)
|
||||
.await?;
|
||||
let path =
|
||||
Self::persist_bytes_with_storage(&storage_prefix, &sanitized_file_name, bytes, storage)
|
||||
.await?;
|
||||
|
||||
let file_info = FileInfo {
|
||||
id: uuid.to_string(),
|
||||
@@ -262,8 +258,8 @@ impl FileInfo {
|
||||
};
|
||||
|
||||
// Remove the object's parent prefix in the object store
|
||||
let (parent_prefix, _file_name) = store::split_object_path(&file_info.path)
|
||||
.map_err(AppError::internal)?;
|
||||
let (parent_prefix, _file_name) =
|
||||
store::split_object_path(&file_info.path).map_err(AppError::internal)?;
|
||||
storage
|
||||
.delete_prefix(&parent_prefix)
|
||||
.await
|
||||
@@ -290,10 +286,7 @@ impl FileInfo {
|
||||
&self,
|
||||
storage: &StorageManager,
|
||||
) -> Result<bytes::Bytes, AppError> {
|
||||
storage
|
||||
.get(&self.path)
|
||||
.await
|
||||
.map_err(AppError::Storage)
|
||||
storage.get(&self.path).await.map_err(AppError::Storage)
|
||||
}
|
||||
|
||||
/// Persist bytes to storage using StorageManager.
|
||||
|
||||
@@ -26,6 +26,19 @@ pub enum IngestionPayload {
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for IngestionPayload {
|
||||
/// An empty text payload, used as a cheap placeholder when the real content
|
||||
/// has been moved out of a task (see [`crate::storage::types::ingestion_task::IngestionTask::take_content`]).
|
||||
fn default() -> Self {
|
||||
Self::Text {
|
||||
text: String::new(),
|
||||
context: String::new(),
|
||||
category: String::new(),
|
||||
user_id: String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared ingest metadata moved or cloned into each payload variant.
|
||||
struct IngestFields {
|
||||
context: String,
|
||||
@@ -440,7 +453,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()> {
|
||||
fn test_create_ingestion_payload_short_content_with_file_only_yields_file() -> anyhow::Result<()>
|
||||
{
|
||||
let context = "ctx";
|
||||
let category = "cat";
|
||||
let user_id = "user123";
|
||||
@@ -501,7 +515,10 @@ mod tests {
|
||||
)?;
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert!(matches!(result.first(), Some(IngestionPayload::File { .. })));
|
||||
assert!(matches!(
|
||||
result.first(),
|
||||
Some(IngestionPayload::File { .. })
|
||||
));
|
||||
assert!(matches!(result.get(1), Some(IngestionPayload::File { .. })));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -211,6 +211,16 @@ impl IngestionTask {
|
||||
self.attempts < self.max_attempts
|
||||
}
|
||||
|
||||
/// Moves the payload out of the task, leaving an empty placeholder behind.
|
||||
///
|
||||
/// The task's `content` is only needed while driving the pipeline; the
|
||||
/// terminal `user_id`, `state`, and bookkeeping fields are stored separately,
|
||||
/// so replacing it with the default placeholder avoids cloning large payloads.
|
||||
#[must_use]
|
||||
pub fn take_content(&mut self) -> IngestionPayload {
|
||||
std::mem::take(&mut self.content)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn lease_duration(&self) -> Duration {
|
||||
Duration::from_secs(u64::try_from(self.lease_duration_secs.max(0)).unwrap_or(0))
|
||||
@@ -654,6 +664,18 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_take_content_moves_payload_and_leaves_default() {
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let mut task = IngestionTask::new(payload.clone(), user_id.to_string());
|
||||
|
||||
let taken = task.take_content();
|
||||
|
||||
assert_eq!(taken, payload);
|
||||
assert_eq!(task.content, IngestionPayload::default());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_all_and_add_to_db_empty() -> anyhow::Result<()> {
|
||||
let db = memory_db().await?;
|
||||
@@ -676,8 +698,7 @@ mod tests {
|
||||
},
|
||||
];
|
||||
|
||||
let created =
|
||||
IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
|
||||
let created = IngestionTask::create_all_and_add_to_db(payloads, user_id, &db).await?;
|
||||
|
||||
assert_eq!(created.len(), 2);
|
||||
let first = created.first().expect("expected first task");
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
#![allow(
|
||||
clippy::missing_docs_in_private_items,
|
||||
clippy::module_name_repetitions,
|
||||
)]
|
||||
#![allow(clippy::missing_docs_in_private_items, clippy::module_name_repetitions)]
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
|
||||
use crate::{
|
||||
error::AppError, storage::db::SurrealDbClient,
|
||||
storage::indexes::hnsw_index_overwrite_sql,
|
||||
error::AppError, storage::db::SurrealDbClient, storage::indexes::hnsw_index_overwrite_sql,
|
||||
storage::types::knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
storage::types::system_settings::SystemSettings, stored_object,
|
||||
utils::embedding::generate_embedding,
|
||||
utils::embedding::{generate_embedding_with_params, generate_embedding_with_provider, EmbeddingProvider},
|
||||
};
|
||||
use async_openai::{config::OpenAIConfig, Client};
|
||||
use tokio_retry::{
|
||||
@@ -315,12 +311,12 @@ impl KnowledgeEntity {
|
||||
description: &str,
|
||||
entity_type: &KnowledgeEntityType,
|
||||
db_client: &SurrealDbClient,
|
||||
ai_client: &Client<OpenAIConfig>,
|
||||
embedding_provider: &EmbeddingProvider,
|
||||
) -> Result<(), AppError> {
|
||||
let embedding_input = format!(
|
||||
"name: {name}, description: {description}, type: {entity_type:?}",
|
||||
);
|
||||
let embedding = generate_embedding(ai_client, &embedding_input, db_client).await?;
|
||||
let embedding_input =
|
||||
format!("name: {name}, description: {description}, type: {entity_type:?}",);
|
||||
let embedding =
|
||||
generate_embedding_with_provider(embedding_provider, &embedding_input).await?;
|
||||
|
||||
let entity: KnowledgeEntity = db_client
|
||||
.get_item(id)
|
||||
@@ -334,12 +330,7 @@ impl KnowledgeEntity {
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
|
||||
let emb = KnowledgeEntityEmbedding::new(
|
||||
id,
|
||||
entity.source_id,
|
||||
embedding,
|
||||
entity.user_id,
|
||||
);
|
||||
let emb = KnowledgeEntityEmbedding::new(id, entity.source_id, embedding, entity.user_id);
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
@@ -411,7 +402,7 @@ impl KnowledgeEntity {
|
||||
let retry_strategy = ExponentialBackoff::from_millis(100).map(jitter).take(3);
|
||||
|
||||
let embedding = Retry::spawn(retry_strategy, || {
|
||||
crate::utils::embedding::generate_embedding_with_params(
|
||||
generate_embedding_with_params(
|
||||
openai_client,
|
||||
&embedding_input,
|
||||
new_model,
|
||||
@@ -431,11 +422,7 @@ impl KnowledgeEntity {
|
||||
}
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
entity.source_id.clone(),
|
||||
),
|
||||
(embedding, entity.user_id.clone(), entity.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
@@ -446,9 +433,8 @@ impl KnowledgeEntity {
|
||||
|
||||
// Add all update statements to the embedding table
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
write!(
|
||||
transaction_query,
|
||||
"UPSERT type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
@@ -526,9 +512,7 @@ impl KnowledgeEntity {
|
||||
entity.name, entity.description, entity.entity_type
|
||||
);
|
||||
|
||||
let embedding = provider
|
||||
.embed(&embedding_input)
|
||||
.await?;
|
||||
let embedding = provider.embed(&embedding_input).await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -541,11 +525,7 @@ impl KnowledgeEntity {
|
||||
}
|
||||
new_embeddings.insert(
|
||||
entity.id.clone(),
|
||||
(
|
||||
embedding,
|
||||
entity.user_id.clone(),
|
||||
entity.source_id.clone(),
|
||||
),
|
||||
(embedding, entity.user_id.clone(), entity.source_id.clone()),
|
||||
);
|
||||
}
|
||||
info!("Successfully generated all new embeddings.");
|
||||
@@ -580,9 +560,8 @@ impl KnowledgeEntity {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
write!(
|
||||
transaction_query,
|
||||
"CREATE type::thing('knowledge_entity_embedding', '{id}') SET \
|
||||
@@ -833,7 +812,9 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to apply migrations");
|
||||
|
||||
configure_embedding_dimension(&db, 3).await.expect("configure dim");
|
||||
configure_embedding_dimension(&db, 3)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
KnowledgeEntityEmbedding::redefine_hnsw_index(&db, 3)
|
||||
.await
|
||||
.expect("Failed to redefine index length");
|
||||
@@ -1105,7 +1086,10 @@ mod tests {
|
||||
.await
|
||||
.with_context(|| "store entity with embedding".to_string())?;
|
||||
|
||||
let query = format!("DELETE type::thing('knowledge_entity', '{id}')", id = entity.id);
|
||||
let query = format!(
|
||||
"DELETE type::thing('knowledge_entity', '{id}')",
|
||||
id = entity.id
|
||||
);
|
||||
db.client
|
||||
.query(query)
|
||||
.await
|
||||
|
||||
@@ -51,12 +51,7 @@ impl KnowledgeEntityEmbedding {
|
||||
///
|
||||
/// The embedding record id equals `entity_id` so each entity has at most one embedding row.
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
entity_id: &str,
|
||||
source_id: String,
|
||||
embedding: Vec<f32>,
|
||||
user_id: String,
|
||||
) -> Self {
|
||||
pub fn new(entity_id: &str, source_id: String, embedding: Vec<f32>, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: entity_id.to_owned(),
|
||||
@@ -300,8 +295,7 @@ mod tests {
|
||||
let db = prepare_knowledge_entity_test_db(3).await?;
|
||||
|
||||
let entity = build_knowledge_entity_with_id("entity-dim", "source-dim", "user-dim");
|
||||
let result =
|
||||
KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||
let result = KnowledgeEntity::store_with_embedding(entity, vec![0.1, 0.2], &db).await;
|
||||
|
||||
assert!(matches!(result, Err(AppError::Validation(_))));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use crate::storage::types::user::User;
|
||||
use crate::utils::serde_helpers::deserialize_flexible_id;
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
@@ -43,18 +43,10 @@ impl KnowledgeRelationship {
|
||||
}
|
||||
|
||||
pub async fn store_relationship(&self, db_client: &SurrealDbClient) -> Result<(), AppError> {
|
||||
User::get_and_validate_knowledge_entity(
|
||||
&self.in_,
|
||||
&self.metadata.user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
User::get_and_validate_knowledge_entity(
|
||||
&self.out,
|
||||
&self.metadata.user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
User::get_and_validate_knowledge_entity(&self.in_, &self.metadata.user_id, db_client)
|
||||
.await?;
|
||||
User::get_and_validate_knowledge_entity(&self.out, &self.metadata.user_id, db_client)
|
||||
.await?;
|
||||
|
||||
db_client
|
||||
.client
|
||||
|
||||
@@ -520,7 +520,10 @@ mod tests {
|
||||
.with_context(|| "Failed to patch query prompt".to_string())?;
|
||||
|
||||
assert_eq!(patched.query_system_prompt, sentinel);
|
||||
assert_eq!(patched.ingestion_system_prompt, original.ingestion_system_prompt);
|
||||
assert_eq!(
|
||||
patched.ingestion_system_prompt,
|
||||
original.ingestion_system_prompt
|
||||
);
|
||||
assert_eq!(patched.query_model, original.query_model);
|
||||
assert_eq!(
|
||||
patched.registrations_enabled,
|
||||
|
||||
@@ -74,10 +74,7 @@ impl TextChunk {
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let settings = SystemSettings::get_current(db).await?;
|
||||
TextChunkEmbedding::validate_dimension(
|
||||
&embedding,
|
||||
settings.embedding_dimensions as usize,
|
||||
)?;
|
||||
TextChunkEmbedding::validate_dimension(&embedding, settings.embedding_dimensions as usize)?;
|
||||
|
||||
let chunk_id = chunk.id.clone();
|
||||
let emb = TextChunkEmbedding::new(
|
||||
@@ -308,9 +305,8 @@ impl TextChunk {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
let id = surql_json_string(&id)?;
|
||||
let user_id = surql_json_string(&user_id)?;
|
||||
let source_id = surql_json_string(&source_id)?;
|
||||
@@ -382,9 +378,7 @@ impl TextChunk {
|
||||
info!(progress = i, total = total_chunks, "Re-embedding progress");
|
||||
}
|
||||
|
||||
let embedding = provider
|
||||
.embed(&chunk.chunk)
|
||||
.await?;
|
||||
let embedding = provider.embed(&chunk.chunk).await?;
|
||||
|
||||
// Safety check: ensure the generated embedding has the correct dimension.
|
||||
if embedding.len() != new_dimensions {
|
||||
@@ -429,9 +423,8 @@ impl TextChunk {
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
|
||||
for (id, (embedding, user_id, source_id)) in new_embeddings {
|
||||
let embedding = serde_json::to_string(&embedding).map_err(|e| {
|
||||
AppError::internal(format!("embedding serialization failed: {e}"))
|
||||
})?;
|
||||
let embedding = serde_json::to_string(&embedding)
|
||||
.map_err(|e| AppError::internal(format!("embedding serialization failed: {e}")))?;
|
||||
let id = surql_json_string(&id)?;
|
||||
let user_id = surql_json_string(&user_id)?;
|
||||
let source_id = surql_json_string(&source_id)?;
|
||||
@@ -662,7 +655,9 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
db.apply_migrations().await.expect("migrations");
|
||||
configure_embedding_dimension(&db, 5).await.expect("configure dim");
|
||||
configure_embedding_dimension(&db, 5)
|
||||
.await
|
||||
.expect("configure dim");
|
||||
TextChunkEmbedding::redefine_hnsw_index(&db, 5)
|
||||
.await
|
||||
.expect("redefine index");
|
||||
@@ -1040,11 +1035,7 @@ mod tests {
|
||||
.with_context(|| "migrations".to_string())?;
|
||||
configure_embedding_dimension(&db, 3).await?;
|
||||
|
||||
let chunk = TextChunk::new(
|
||||
"src".to_string(),
|
||||
"body".to_string(),
|
||||
"user".to_string(),
|
||||
);
|
||||
let chunk = TextChunk::new("src".to_string(), "body".to_string(), "user".to_string());
|
||||
|
||||
let err = TextChunk::store_with_embedding(chunk, vec![0.1, 0.2], &db)
|
||||
.await
|
||||
|
||||
@@ -299,12 +299,7 @@ mod tests {
|
||||
("chunk-s2", source_id, vec![0.2]),
|
||||
("chunk-other", other_source, vec![0.3]),
|
||||
] {
|
||||
let emb = TextChunkEmbedding::new(
|
||||
key,
|
||||
src.to_string(),
|
||||
vec,
|
||||
user_id.to_string(),
|
||||
);
|
||||
let emb = TextChunkEmbedding::new(key, src.to_string(), vec, user_id.to_string());
|
||||
db.upsert_item(emb)
|
||||
.await
|
||||
.with_context(|| format!("store embedding for {key}"))?;
|
||||
|
||||
@@ -118,9 +118,7 @@ impl TextContent {
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
if updated.is_none() {
|
||||
return Err(AppError::NotFound(format!(
|
||||
"text content {id} not found"
|
||||
)));
|
||||
return Err(AppError::NotFound(format!("text content {id} not found")));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -142,7 +140,8 @@ impl TextContent {
|
||||
.await
|
||||
.map_err(AppError::Database)?;
|
||||
|
||||
let existing: Option<surrealdb::sql::Thing> = response.take(0).map_err(AppError::Database)?;
|
||||
let existing: Option<surrealdb::sql::Thing> =
|
||||
response.take(0).map_err(AppError::Database)?;
|
||||
|
||||
Ok(existing.is_some())
|
||||
}
|
||||
@@ -254,10 +253,7 @@ impl TextContent {
|
||||
for content in contents {
|
||||
let label = build_source_label(&content);
|
||||
labels.insert(content.id.clone(), label.clone());
|
||||
labels.insert(
|
||||
format!("{}:{}", Self::table_name(), content.id),
|
||||
label,
|
||||
);
|
||||
labels.insert(format!("{}:{}", Self::table_name(), content.id), label);
|
||||
}
|
||||
|
||||
Ok(labels)
|
||||
|
||||
@@ -8,8 +8,7 @@ use crate::storage::{
|
||||
db::SurrealDbClient,
|
||||
indexes::{ensure_runtime, rebuild},
|
||||
types::{
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
||||
system_settings::SystemSettings,
|
||||
knowledge_entity_embedding::KnowledgeEntityEmbedding, system_settings::SystemSettings,
|
||||
text_chunk_embedding::TextChunkEmbedding,
|
||||
},
|
||||
};
|
||||
@@ -27,9 +26,7 @@ pub async fn setup_test_db() -> Result<SurrealDbClient> {
|
||||
.await
|
||||
.context("start in-memory surrealdb")?;
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.context("apply migrations")?;
|
||||
db.apply_migrations().await.context("apply migrations")?;
|
||||
|
||||
Ok(db)
|
||||
}
|
||||
|
||||
@@ -59,9 +59,7 @@ async fn run_fastembed(
|
||||
texts: Vec<String>,
|
||||
) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
match tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
||||
let mut guard = model
|
||||
.lock()
|
||||
.map_err(EmbeddingError::mutex_poisoned)?;
|
||||
let mut guard = model.lock().map_err(EmbeddingError::mutex_poisoned)?;
|
||||
guard.embed(texts, None).map_err(EmbeddingError::fastembed)
|
||||
})
|
||||
.await
|
||||
@@ -215,21 +213,22 @@ impl EmbeddingProvider {
|
||||
let model_name_for_task = model_name.clone();
|
||||
let model_name_code = model_name.to_string();
|
||||
|
||||
let (model, dimension) = match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model =
|
||||
TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info = EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||
};
|
||||
let (model, dimension) =
|
||||
match tokio::task::spawn_blocking(move || -> Result<_, EmbeddingError> {
|
||||
let model = TextEmbedding::try_new(options).map_err(EmbeddingError::fastembed)?;
|
||||
let info =
|
||||
EmbeddingModel::get_model_info(&model_name_for_task).ok_or_else(|| {
|
||||
EmbeddingError::Config(format!(
|
||||
"fastembed model metadata missing for {model_name_code}"
|
||||
))
|
||||
})?;
|
||||
Ok((model, info.dim))
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(result) => result?,
|
||||
Err(join_error) => return Err(EmbeddingError::from(join_error)),
|
||||
};
|
||||
|
||||
Ok(EmbeddingProvider {
|
||||
inner: EmbeddingInner::FastEmbed {
|
||||
@@ -440,10 +439,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn embedding_backend_defaults_to_fastembed() {
|
||||
assert_eq!(
|
||||
EmbeddingBackend::default(),
|
||||
EmbeddingBackend::FastEmbed
|
||||
);
|
||||
assert_eq!(EmbeddingBackend::default(), EmbeddingBackend::FastEmbed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -472,7 +468,9 @@ mod tests {
|
||||
#[test]
|
||||
fn embedding_backend_from_str_accepts_aliases() {
|
||||
assert_eq!(
|
||||
"fast-embed".parse::<EmbeddingBackend>().expect("fast-embed"),
|
||||
"fast-embed"
|
||||
.parse::<EmbeddingBackend>()
|
||||
.expect("fast-embed"),
|
||||
EmbeddingBackend::FastEmbed
|
||||
);
|
||||
assert_eq!(
|
||||
|
||||
Reference in New Issue
Block a user