mirror of
https://github.com/perstarkse/minne.git
synced 2026-07-01 02:21:34 +02:00
feat: pool fastembed, batch embeddings, and reconcile embedding config on startup
This commit is contained in:
@@ -59,8 +59,7 @@ pub struct RetrievedEntity {
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
embedding_provider: Option<&common::utils::embedding::EmbeddingProvider>,
|
||||
embedding_provider: &common::utils::embedding::EmbeddingProvider,
|
||||
input_text: &str,
|
||||
user_id: &str,
|
||||
config: RetrievalConfig,
|
||||
@@ -68,7 +67,6 @@ pub async fn retrieve(
|
||||
) -> Result<RetrievalOutput, AppError> {
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client,
|
||||
openai_client,
|
||||
embedding_provider,
|
||||
input_text,
|
||||
user_id,
|
||||
@@ -82,12 +80,16 @@ pub async fn retrieve(
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::{self};
|
||||
use async_openai::Client;
|
||||
use common::storage::indexes::ensure_runtime;
|
||||
use common::storage::types::knowledge_entity::{KnowledgeEntity, KnowledgeEntityType};
|
||||
use common::storage::types::system_settings::SystemSettings;
|
||||
use common::utils::embedding::EmbeddingProvider;
|
||||
use uuid::Uuid;
|
||||
|
||||
fn test_embedding_provider() -> EmbeddingProvider {
|
||||
EmbeddingProvider::new_hashed(3).unwrap_or_else(|_| unreachable!())
|
||||
}
|
||||
|
||||
fn test_embedding() -> Vec<f32> {
|
||||
vec![0.9, 0.1, 0.0]
|
||||
}
|
||||
@@ -135,11 +137,10 @@ mod tests {
|
||||
|
||||
TextChunk::store_with_embedding(chunk.clone(), chunk_embedding_primary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
config: RetrievalConfig::default(),
|
||||
@@ -181,11 +182,10 @@ mod tests {
|
||||
TextChunk::store_with_embedding(primary_chunk, chunk_embedding_primary(), &db).await?;
|
||||
TextChunk::store_with_embedding(secondary_chunk, chunk_embedding_secondary(), &db).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "Rust concurrency async tasks",
|
||||
user_id,
|
||||
config: RetrievalConfig::default(),
|
||||
@@ -236,11 +236,10 @@ mod tests {
|
||||
);
|
||||
db.store_item(entity).await?;
|
||||
|
||||
let openai_client = Client::new();
|
||||
let embedding_provider = test_embedding_provider();
|
||||
let params = pipeline::RetrievalParams {
|
||||
db_client: &db,
|
||||
openai_client: &openai_client,
|
||||
embedding_provider: None,
|
||||
embedding_provider: &embedding_provider,
|
||||
input_text: "async rust programming",
|
||||
user_id,
|
||||
config: RetrievalConfig::with_entities(),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use async_openai::Client;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::text_chunk::TextChunk},
|
||||
@@ -18,8 +17,7 @@ use super::{
|
||||
/// Mutable working state threaded through every retrieval stage.
|
||||
pub(crate) struct PipelineContext<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a EmbeddingProvider>,
|
||||
pub embedding_provider: &'a EmbeddingProvider,
|
||||
pub input_text: String,
|
||||
pub user_id: String,
|
||||
pub config: RetrievalConfig,
|
||||
@@ -36,7 +34,6 @@ impl<'a> PipelineContext<'a> {
|
||||
pub fn new(params: RetrievalParams<'a>) -> Self {
|
||||
Self {
|
||||
db_client: params.db_client,
|
||||
openai_client: params.openai_client,
|
||||
embedding_provider: params.embedding_provider,
|
||||
input_text: params.input_text.to_owned(),
|
||||
user_id: params.user_id.to_owned(),
|
||||
|
||||
@@ -7,7 +7,6 @@ pub use config::{RetrievalConfig, RetrievalTuning};
|
||||
pub use diagnostics::Diagnostics;
|
||||
|
||||
use crate::{round_score, RetrievalOutput, RetrievedEntity};
|
||||
use async_openai::Client;
|
||||
use async_trait::async_trait;
|
||||
use common::{error::AppError, storage::db::SurrealDbClient};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -91,8 +90,7 @@ pub struct RunOutput<T> {
|
||||
/// Inputs required to run a retrieval.
|
||||
pub struct RetrievalParams<'a> {
|
||||
pub db_client: &'a SurrealDbClient,
|
||||
pub openai_client: &'a Client<async_openai::config::OpenAIConfig>,
|
||||
pub embedding_provider: Option<&'a common::utils::embedding::EmbeddingProvider>,
|
||||
pub embedding_provider: &'a common::utils::embedding::EmbeddingProvider,
|
||||
pub input_text: &'a str,
|
||||
pub user_id: &'a str,
|
||||
pub config: RetrievalConfig,
|
||||
|
||||
@@ -2,7 +2,6 @@ use async_trait::async_trait;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use fastembed::RerankResult;
|
||||
use std::collections::HashMap;
|
||||
@@ -97,11 +96,7 @@ pub async fn embed(ctx: &mut PipelineContext<'_>) -> Result<(), AppError> {
|
||||
debug!("Reusing cached query embedding for hybrid retrieval");
|
||||
} else {
|
||||
debug!("Generating query embedding for hybrid retrieval");
|
||||
let embedding = if let Some(provider) = ctx.embedding_provider {
|
||||
provider.embed(&ctx.input_text).await?
|
||||
} else {
|
||||
generate_embedding(ctx.openai_client, &ctx.input_text, ctx.db_client).await?
|
||||
};
|
||||
let embedding = ctx.embedding_provider.embed(&ctx.input_text).await?;
|
||||
ctx.query_embedding = Some(embedding);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user