mirror of
https://github.com/perstarkse/minne.git
synced 2026-03-02 11:40:01 +01:00
Compare commits
7 Commits
v0.2.2
...
hybrid-ret
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb7f625b81 | ||
|
|
dc40cf7663 | ||
|
|
aa0b1462a1 | ||
|
|
41fc7bb99c | ||
|
|
61d8d7abe7 | ||
|
|
b7344644dc | ||
|
|
3742598a6d |
@@ -1,6 +1,9 @@
|
||||
# Changelog
|
||||
## Unreleased
|
||||
|
||||
## Version 0.2.3 (2025-10-12)
|
||||
- Fix changing vector dimensions on a fresh database (#3)
|
||||
|
||||
## Version 0.2.2 (2025-10-07)
|
||||
- Support for ingestion of PDF files
|
||||
- Improved ingestion speed
|
||||
|
||||
31
Cargo.lock
generated
31
Cargo.lock
generated
@@ -1322,6 +1322,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"state-machines",
|
||||
"surrealdb",
|
||||
"surrealdb-migrations",
|
||||
"tempfile",
|
||||
@@ -3291,7 +3292,7 @@ checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
|
||||
|
||||
[[package]]
|
||||
name = "main"
|
||||
version = "0.2.1"
|
||||
version = "0.2.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"api-router",
|
||||
@@ -5400,6 +5401,34 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "state-machines"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "806ba0bf43ae158b229036d8a84601649a58d9761e718b5e0e07c2953803f4c1"
|
||||
dependencies = [
|
||||
"state-machines-core",
|
||||
"state-machines-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "state-machines-core"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "949cc50e84bed6234117f28a0ba2980dc35e9c17984ffe4e0a3364fba3e77540"
|
||||
|
||||
[[package]]
|
||||
name = "state-machines-macro"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8322f5aa92d31b3c05faa1ec3231b82da479a20706836867d67ae89ce74927bd"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"state-machines-core",
|
||||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions_next"
|
||||
version = "1.1.2"
|
||||
|
||||
@@ -55,6 +55,7 @@ tokio-retry = "0.3.0"
|
||||
base64 = "0.22.1"
|
||||
object_store = { version = "0.11.2" }
|
||||
bytes = "1.7.1"
|
||||
state-machines = "0.2.0"
|
||||
|
||||
[profile.dist]
|
||||
inherits = "release"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
[](https://www.gnu.org/licenses/agpl-3.0)
|
||||
[](https://github.com/perstarkse/minne/releases/latest)
|
||||
|
||||

|
||||

|
||||
|
||||
## Demo deployment
|
||||
|
||||
@@ -28,6 +28,8 @@ You may switch and choose between models used, and have the possiblity to change
|
||||
|
||||
The application is built for speed and efficiency using Rust with a Server-Side Rendered (SSR) frontend (HTMX and minimal JavaScript). It's fully responsive, offering a complete mobile interface for reading, editing, and managing your content, including the graph database itself. **PWA (Progressive Web App) support** means you can "install" Minne to your device for a native-like experience. For quick capture on the go on iOS, a [**Shortcut**](https://www.icloud.com/shortcuts/e433fbd7602f4e2eaa70dca162323477) makes sending content to your Minne instance a breeze.
|
||||
|
||||
A hybrid retrieval layer blends embeddings, full-text search, and graph signals to surface the best context when augmenting chat responses and when building new relationships during ingestion.
|
||||
|
||||
Minne is open source (AGPL), self-hostable, and can be deployed flexibly: via Nix, Docker Compose, pre-built binaries, or by building from source. It can run as a single `main` binary or as separate `server` and `worker` processes for optimized resource allocation.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
@@ -41,6 +41,7 @@ surrealdb-migrations = { workspace = true }
|
||||
tokio-retry = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
state-machines = { workspace = true }
|
||||
|
||||
|
||||
[features]
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
-- Add FTS indexes for searching name and description on entities
|
||||
|
||||
DEFINE ANALYZER IF NOT EXISTS app_en_fts_analyzer
|
||||
TOKENIZERS class
|
||||
FILTERS lowercase, ascii, snowball(english);
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_name_idx ON TABLE knowledge_entity
|
||||
FIELDS name
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS knowledge_entity_fts_description_idx ON TABLE knowledge_entity
|
||||
FIELDS description
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS text_chunk_fts_chunk_idx ON TABLE text_chunk
|
||||
FIELDS chunk
|
||||
SEARCH ANALYZER app_en_fts_analyzer BM25;
|
||||
173
common/migrations/20251012_205900_state_machine_migration.surql
Normal file
173
common/migrations/20251012_205900_state_machine_migration.surql
Normal file
@@ -0,0 +1,173 @@
|
||||
-- State machine migration for ingestion_task records
|
||||
|
||||
DEFINE FIELD IF NOT EXISTS state ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS max_attempts ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS scheduled_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS locked_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS lease_duration_secs ON TABLE ingestion_task TYPE option<number>;
|
||||
DEFINE FIELD IF NOT EXISTS worker_id ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_code ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS error_message ON TABLE ingestion_task TYPE option<string>;
|
||||
DEFINE FIELD IF NOT EXISTS last_error_at ON TABLE ingestion_task TYPE option<datetime>;
|
||||
DEFINE FIELD IF NOT EXISTS priority ON TABLE ingestion_task TYPE option<number>;
|
||||
|
||||
REMOVE FIELD status ON TABLE ingestion_task;
|
||||
DEFINE FIELD status ON TABLE ingestion_task TYPE option<object>;
|
||||
|
||||
DEFINE INDEX IF NOT EXISTS idx_ingestion_task_state_sched ON TABLE ingestion_task FIELDS state, scheduled_at;
|
||||
|
||||
LET $needs_migration = (SELECT count() AS count FROM type::table('ingestion_task') WHERE state = NONE)[0].count;
|
||||
|
||||
IF $needs_migration > 0 THEN {
|
||||
-- Created -> Pending
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF created_at != NONE THEN created_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Created";
|
||||
|
||||
-- InProgress -> Processing
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Processing",
|
||||
attempts = IF status.attempts != NONE THEN status.attempts ELSE 1 END,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
locked_at = IF status.last_attempt != NONE THEN status.last_attempt ELSE time::now() END,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "InProgress";
|
||||
|
||||
-- Completed -> Succeeded
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Succeeded",
|
||||
attempts = 1,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Completed";
|
||||
|
||||
-- Error -> DeadLetter (terminal failure)
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "DeadLetter",
|
||||
attempts = 3,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = status.message,
|
||||
last_error_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Error";
|
||||
|
||||
-- Cancelled -> Cancelled
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Cancelled",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE
|
||||
AND status != NONE
|
||||
AND status.name = "Cancelled";
|
||||
|
||||
-- Fallback for any remaining records missing state
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET
|
||||
state = "Pending",
|
||||
attempts = 0,
|
||||
max_attempts = 3,
|
||||
scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END,
|
||||
locked_at = NONE,
|
||||
lease_duration_secs = 300,
|
||||
worker_id = NONE,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE,
|
||||
priority = 0
|
||||
WHERE state = NONE;
|
||||
} END;
|
||||
|
||||
-- Ensure defaults for newly added fields
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET max_attempts = 3
|
||||
WHERE max_attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET lease_duration_secs = 300
|
||||
WHERE lease_duration_secs = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET attempts = 0
|
||||
WHERE attempts = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET priority = 0
|
||||
WHERE priority = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET scheduled_at = IF updated_at != NONE THEN updated_at ELSE time::now() END
|
||||
WHERE scheduled_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET locked_at = NONE
|
||||
WHERE locked_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET worker_id = NONE
|
||||
WHERE worker_id != NONE AND worker_id = "";
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_code = NONE
|
||||
WHERE error_code = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET error_message = NONE
|
||||
WHERE error_message = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET last_error_at = NONE
|
||||
WHERE last_error_at = NONE;
|
||||
|
||||
UPDATE type::table('ingestion_task')
|
||||
SET status = NONE
|
||||
WHERE status != NONE;
|
||||
@@ -80,15 +80,18 @@ impl SurrealDbClient {
|
||||
/// Operation to rebuild indexes
|
||||
pub async fn rebuild_indexes(&self) -> Result<(), Error> {
|
||||
debug!("Rebuilding indexes");
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity")
|
||||
.await?;
|
||||
self.client
|
||||
.query("REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content")
|
||||
.await?;
|
||||
let rebuild_sql = r#"
|
||||
BEGIN TRANSACTION;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_chunks ON text_chunk;
|
||||
REBUILD INDEX IF EXISTS idx_embedding_entities ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_content_fts_idx ON text_content;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_name_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS knowledge_entity_fts_description_idx ON knowledge_entity;
|
||||
REBUILD INDEX IF EXISTS text_chunk_fts_chunk_idx ON text_chunk;
|
||||
COMMIT TRANSACTION;
|
||||
"#;
|
||||
|
||||
self.client.query(rebuild_sql).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,116 +1,529 @@
|
||||
use futures::Stream;
|
||||
use surrealdb::{opt::PatchOp, Notification};
|
||||
use std::time::Duration;
|
||||
|
||||
use chrono::Duration as ChronoDuration;
|
||||
use state_machines::state_machine;
|
||||
use surrealdb::sql::Datetime as SurrealDatetime;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{error::AppError, storage::db::SurrealDbClient, stored_object};
|
||||
|
||||
use super::ingestion_payload::IngestionPayload;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "name")]
|
||||
pub enum IngestionTaskStatus {
|
||||
Created,
|
||||
InProgress {
|
||||
attempts: u32,
|
||||
last_attempt: DateTime<Utc>,
|
||||
},
|
||||
Completed,
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
pub const DEFAULT_LEASE_SECS: i64 = 300;
|
||||
pub const DEFAULT_PRIORITY: i32 = 0;
|
||||
|
||||
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
|
||||
pub enum TaskState {
|
||||
#[serde(rename = "Pending")]
|
||||
#[default]
|
||||
Pending,
|
||||
#[serde(rename = "Reserved")]
|
||||
Reserved,
|
||||
#[serde(rename = "Processing")]
|
||||
Processing,
|
||||
#[serde(rename = "Succeeded")]
|
||||
Succeeded,
|
||||
#[serde(rename = "Failed")]
|
||||
Failed,
|
||||
#[serde(rename = "Cancelled")]
|
||||
Cancelled,
|
||||
#[serde(rename = "DeadLetter")]
|
||||
DeadLetter,
|
||||
}
|
||||
|
||||
impl TaskState {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Succeeded",
|
||||
TaskState::Failed => "Failed",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "DeadLetter",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_terminal(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
TaskState::Succeeded | TaskState::Cancelled | TaskState::DeadLetter
|
||||
)
|
||||
}
|
||||
|
||||
pub fn display_label(&self) -> &'static str {
|
||||
match self {
|
||||
TaskState::Pending => "Pending",
|
||||
TaskState::Reserved => "Reserved",
|
||||
TaskState::Processing => "Processing",
|
||||
TaskState::Succeeded => "Completed",
|
||||
TaskState::Failed => "Retrying",
|
||||
TaskState::Cancelled => "Cancelled",
|
||||
TaskState::DeadLetter => "Dead Letter",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct TaskErrorInfo {
|
||||
pub code: Option<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum TaskTransition {
|
||||
StartProcessing,
|
||||
Succeed,
|
||||
Fail,
|
||||
Cancel,
|
||||
DeadLetter,
|
||||
Release,
|
||||
}
|
||||
|
||||
impl TaskTransition {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
TaskTransition::StartProcessing => "start_processing",
|
||||
TaskTransition::Succeed => "succeed",
|
||||
TaskTransition::Fail => "fail",
|
||||
TaskTransition::Cancel => "cancel",
|
||||
TaskTransition::DeadLetter => "deadletter",
|
||||
TaskTransition::Release => "release",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod lifecycle {
|
||||
use super::state_machine;
|
||||
|
||||
state_machine! {
|
||||
name: TaskLifecycleMachine,
|
||||
initial: Pending,
|
||||
states: [Pending, Reserved, Processing, Succeeded, Failed, Cancelled, DeadLetter],
|
||||
events {
|
||||
reserve {
|
||||
transition: { from: Pending, to: Reserved }
|
||||
transition: { from: Failed, to: Reserved }
|
||||
}
|
||||
start_processing {
|
||||
transition: { from: Reserved, to: Processing }
|
||||
}
|
||||
succeed {
|
||||
transition: { from: Processing, to: Succeeded }
|
||||
}
|
||||
fail {
|
||||
transition: { from: Processing, to: Failed }
|
||||
}
|
||||
cancel {
|
||||
transition: { from: Pending, to: Cancelled }
|
||||
transition: { from: Reserved, to: Cancelled }
|
||||
transition: { from: Processing, to: Cancelled }
|
||||
}
|
||||
deadletter {
|
||||
transition: { from: Failed, to: DeadLetter }
|
||||
}
|
||||
release {
|
||||
transition: { from: Reserved, to: Pending }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn pending() -> TaskLifecycleMachine<(), Pending> {
|
||||
TaskLifecycleMachine::new(())
|
||||
}
|
||||
|
||||
pub(super) fn reserved() -> TaskLifecycleMachine<(), Reserved> {
|
||||
pending()
|
||||
.reserve()
|
||||
.expect("reserve transition from Pending should exist")
|
||||
}
|
||||
|
||||
pub(super) fn processing() -> TaskLifecycleMachine<(), Processing> {
|
||||
reserved()
|
||||
.start_processing()
|
||||
.expect("start_processing transition from Reserved should exist")
|
||||
}
|
||||
|
||||
pub(super) fn failed() -> TaskLifecycleMachine<(), Failed> {
|
||||
processing()
|
||||
.fail()
|
||||
.expect("fail transition from Processing should exist")
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_transition(state: &TaskState, event: TaskTransition) -> AppError {
|
||||
AppError::Validation(format!(
|
||||
"Invalid task transition: {} -> {}",
|
||||
state.as_str(),
|
||||
event.as_str()
|
||||
))
|
||||
}
|
||||
|
||||
stored_object!(IngestionTask, "ingestion_task", {
|
||||
content: IngestionPayload,
|
||||
status: IngestionTaskStatus,
|
||||
user_id: String
|
||||
state: TaskState,
|
||||
user_id: String,
|
||||
attempts: u32,
|
||||
max_attempts: u32,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime")]
|
||||
scheduled_at: chrono::DateTime<chrono::Utc>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
locked_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
lease_duration_secs: i64,
|
||||
worker_id: Option<String>,
|
||||
error_code: Option<String>,
|
||||
error_message: Option<String>,
|
||||
#[serde(
|
||||
serialize_with = "serialize_option_datetime",
|
||||
deserialize_with = "deserialize_option_datetime",
|
||||
default
|
||||
)]
|
||||
last_error_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
priority: i32
|
||||
});
|
||||
|
||||
pub const MAX_ATTEMPTS: u32 = 3;
|
||||
|
||||
impl IngestionTask {
|
||||
pub async fn new(content: IngestionPayload, user_id: String) -> Self {
|
||||
let now = Utc::now();
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
content,
|
||||
status: IngestionTaskStatus::Created,
|
||||
state: TaskState::Pending,
|
||||
user_id,
|
||||
attempts: 0,
|
||||
max_attempts: MAX_ATTEMPTS,
|
||||
scheduled_at: now,
|
||||
locked_at: None,
|
||||
lease_duration_secs: DEFAULT_LEASE_SECS,
|
||||
worker_id: None,
|
||||
error_code: None,
|
||||
error_message: None,
|
||||
last_error_at: None,
|
||||
priority: DEFAULT_PRIORITY,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
user_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new job and stores it in the database
|
||||
pub fn can_retry(&self) -> bool {
|
||||
self.attempts < self.max_attempts
|
||||
}
|
||||
|
||||
pub fn lease_duration(&self) -> Duration {
|
||||
Duration::from_secs(self.lease_duration_secs.max(0) as u64)
|
||||
}
|
||||
|
||||
pub async fn create_and_add_to_db(
|
||||
content: IngestionPayload,
|
||||
user_id: String,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let task = Self::new(content, user_id).await;
|
||||
|
||||
db.store_item(task.clone()).await?;
|
||||
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
// Update job status
|
||||
pub async fn update_status(
|
||||
id: &str,
|
||||
status: IngestionTaskStatus,
|
||||
pub async fn claim_next_ready(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<(), AppError> {
|
||||
let _job: Option<Self> = db
|
||||
.update((Self::table_name(), id))
|
||||
.patch(PatchOp::replace("/status", status))
|
||||
.patch(PatchOp::replace(
|
||||
"/updated_at",
|
||||
surrealdb::Datetime::from(Utc::now()),
|
||||
worker_id: &str,
|
||||
now: chrono::DateTime<chrono::Utc>,
|
||||
lease_duration: Duration,
|
||||
) -> Result<Option<IngestionTask>, AppError> {
|
||||
debug_assert!(lifecycle::pending().reserve().is_ok());
|
||||
debug_assert!(lifecycle::failed().reserve().is_ok());
|
||||
|
||||
const CLAIM_QUERY: &str = r#"
|
||||
UPDATE (
|
||||
SELECT * FROM type::table($table)
|
||||
WHERE state IN $candidate_states
|
||||
AND scheduled_at <= $now
|
||||
AND (
|
||||
attempts < max_attempts
|
||||
OR state IN $sticky_states
|
||||
)
|
||||
AND (
|
||||
locked_at = NONE
|
||||
OR time::unix($now) - time::unix(locked_at) >= lease_duration_secs
|
||||
)
|
||||
ORDER BY priority DESC, scheduled_at ASC, created_at ASC
|
||||
LIMIT 1
|
||||
)
|
||||
SET state = $reserved_state,
|
||||
attempts = if state IN $increment_states THEN
|
||||
if attempts + 1 > max_attempts THEN max_attempts ELSE attempts + 1 END
|
||||
ELSE
|
||||
attempts
|
||||
END,
|
||||
locked_at = $now,
|
||||
worker_id = $worker_id,
|
||||
lease_duration_secs = $lease_secs,
|
||||
updated_at = $now
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CLAIM_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind((
|
||||
"candidate_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind((
|
||||
"sticky_states",
|
||||
vec![TaskState::Reserved.as_str(), TaskState::Processing.as_str()],
|
||||
))
|
||||
.bind((
|
||||
"increment_states",
|
||||
vec![TaskState::Pending.as_str(), TaskState::Failed.as_str()],
|
||||
))
|
||||
.bind(("reserved_state", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", worker_id.to_string()))
|
||||
.bind(("lease_secs", lease_duration.as_secs() as i64))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
let task: Option<IngestionTask> = result.take(0)?;
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
/// Listen for new jobs
|
||||
pub async fn listen_for_tasks(
|
||||
pub async fn mark_processing(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const START_PROCESSING_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $processing,
|
||||
updated_at = $now,
|
||||
locked_at = $now
|
||||
WHERE state = $reserved AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(START_PROCESSING_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::StartProcessing))
|
||||
}
|
||||
|
||||
pub async fn mark_succeeded(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const COMPLETE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $succeeded,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = NONE,
|
||||
error_message = NONE,
|
||||
last_error_at = NONE
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(COMPLETE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("succeeded", TaskState::Succeeded.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Succeed))
|
||||
}
|
||||
|
||||
pub async fn mark_failed(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
retry_delay: Duration,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<impl Stream<Item = Result<Notification<Self>, surrealdb::Error>>, surrealdb::Error>
|
||||
{
|
||||
db.listen::<Self>().await
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
let now = chrono::Utc::now();
|
||||
let retry_at = now
|
||||
+ ChronoDuration::from_std(retry_delay).unwrap_or_else(|_| ChronoDuration::seconds(30));
|
||||
|
||||
const FAIL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $failed,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $retry_at,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $processing AND worker_id = $worker_id
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let mut result = db
|
||||
.client
|
||||
.query(FAIL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("processing", TaskState::Processing.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("retry_at", SurrealDatetime::from(retry_at)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.bind(("worker_id", self.worker_id.clone().unwrap_or_default()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Fail))
|
||||
}
|
||||
|
||||
/// Get all unfinished tasks, ie newly created and in progress up two times
|
||||
pub async fn get_unfinished_tasks(db: &SurrealDbClient) -> Result<Vec<Self>, AppError> {
|
||||
let jobs: Vec<Self> = db
|
||||
pub async fn mark_dead_letter(
|
||||
&self,
|
||||
error: TaskErrorInfo,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<IngestionTask, AppError> {
|
||||
const DEAD_LETTER_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $dead,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE,
|
||||
scheduled_at = $now,
|
||||
error_code = $error_code,
|
||||
error_message = $error_message,
|
||||
last_error_at = $now
|
||||
WHERE state = $failed
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(DEAD_LETTER_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("dead", TaskState::DeadLetter.as_str()))
|
||||
.bind(("failed", TaskState::Failed.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.bind(("error_code", error.code.clone()))
|
||||
.bind(("error_message", error.message.clone()))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::DeadLetter))
|
||||
}
|
||||
|
||||
pub async fn mark_cancelled(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const CANCEL_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $cancelled,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state IN $allow_states
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(CANCEL_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("cancelled", TaskState::Cancelled.as_str()))
|
||||
.bind((
|
||||
"allow_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Cancel))
|
||||
}
|
||||
|
||||
pub async fn release(&self, db: &SurrealDbClient) -> Result<IngestionTask, AppError> {
|
||||
const RELEASE_QUERY: &str = r#"
|
||||
UPDATE type::thing($table, $id)
|
||||
SET state = $pending,
|
||||
updated_at = $now,
|
||||
locked_at = NONE,
|
||||
worker_id = NONE
|
||||
WHERE state = $reserved
|
||||
RETURN *;
|
||||
"#;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let mut result = db
|
||||
.client
|
||||
.query(RELEASE_QUERY)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("id", self.id.clone()))
|
||||
.bind(("pending", TaskState::Pending.as_str()))
|
||||
.bind(("reserved", TaskState::Reserved.as_str()))
|
||||
.bind(("now", SurrealDatetime::from(now)))
|
||||
.await?;
|
||||
|
||||
let updated: Option<IngestionTask> = result.take(0)?;
|
||||
updated.ok_or_else(|| invalid_transition(&self.state, TaskTransition::Release))
|
||||
}
|
||||
|
||||
pub async fn get_unfinished_tasks(
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let tasks: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE
|
||||
status.name = 'Created'
|
||||
OR (
|
||||
status.name = 'InProgress'
|
||||
AND status.attempts < $max_attempts
|
||||
)
|
||||
ORDER BY created_at ASC",
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE state IN $active_states
|
||||
ORDER BY scheduled_at ASC, created_at ASC",
|
||||
)
|
||||
.bind(("table", Self::table_name()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
TaskState::Failed.as_str(),
|
||||
],
|
||||
))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
Ok(tasks)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
|
||||
// Helper function to create a test ingestion payload
|
||||
fn create_test_payload(user_id: &str) -> IngestionPayload {
|
||||
fn create_payload(user_id: &str) -> IngestionPayload {
|
||||
IngestionPayload::Text {
|
||||
text: "Test content".to_string(),
|
||||
context: "Test context".to_string(),
|
||||
@@ -119,182 +532,197 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_ingestion_task() {
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
async fn memory_db() -> SurrealDbClient {
|
||||
let namespace = "test_ns";
|
||||
let database = Uuid::new_v4().to_string();
|
||||
SurrealDbClient::memory(namespace, &database)
|
||||
.await
|
||||
.expect("in-memory surrealdb")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_new_task_defaults() {
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
|
||||
// Verify task properties
|
||||
assert_eq!(task.user_id, user_id);
|
||||
assert_eq!(task.content, payload);
|
||||
assert!(matches!(task.status, IngestionTaskStatus::Created));
|
||||
assert!(!task.id.is_empty());
|
||||
assert_eq!(task.state, TaskState::Pending);
|
||||
assert_eq!(task.attempts, 0);
|
||||
assert_eq!(task.max_attempts, MAX_ATTEMPTS);
|
||||
assert!(task.locked_at.is_none());
|
||||
assert!(task.worker_id.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_and_add_to_db() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_create_and_store_task() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create and store task
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
let created =
|
||||
IngestionTask::create_and_add_to_db(payload.clone(), user_id.to_string(), &db)
|
||||
.await
|
||||
.expect("store");
|
||||
|
||||
let stored: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&created.id)
|
||||
.await
|
||||
.expect("Failed to create and add task to db");
|
||||
.expect("fetch");
|
||||
|
||||
// Query to verify task was stored
|
||||
let query = format!(
|
||||
"SELECT * FROM {} WHERE user_id = '{}'",
|
||||
IngestionTask::table_name(),
|
||||
user_id
|
||||
);
|
||||
let mut result = db.query(query).await.expect("Query failed");
|
||||
let tasks: Vec<IngestionTask> = result.take(0).unwrap_or_default();
|
||||
|
||||
// Verify task is in the database
|
||||
assert!(!tasks.is_empty(), "Task should exist in the database");
|
||||
let stored_task = &tasks[0];
|
||||
assert_eq!(stored_task.user_id, user_id);
|
||||
assert!(matches!(stored_task.status, IngestionTaskStatus::Created));
|
||||
let stored = stored.expect("task exists");
|
||||
assert_eq!(stored.id, created.id);
|
||||
assert_eq!(stored.state, TaskState::Pending);
|
||||
assert_eq!(stored.attempts, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_status() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_claim_and_transition() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string()).await;
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
// Create task manually
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task_id = task.id.clone();
|
||||
let worker_id = "worker-1";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim");
|
||||
|
||||
// Store task
|
||||
db.store_item(task).await.expect("Failed to store task");
|
||||
let claimed = claimed.expect("task claimed");
|
||||
assert_eq!(claimed.state, TaskState::Reserved);
|
||||
assert_eq!(claimed.worker_id.as_deref(), Some(worker_id));
|
||||
|
||||
// Update status to InProgress
|
||||
let now = Utc::now();
|
||||
let new_status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: now,
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
assert_eq!(processing.state, TaskState::Processing);
|
||||
|
||||
let succeeded = processing.mark_succeeded(&db).await.expect("succeeded");
|
||||
assert_eq!(succeeded.state, TaskState::Succeeded);
|
||||
assert!(succeeded.worker_id.is_none());
|
||||
assert!(succeeded.locked_at.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fail_and_dead_letter() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
let task = IngestionTask::new(payload, user_id.to_string()).await;
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let worker_id = "worker-dead";
|
||||
let now = chrono::Utc::now();
|
||||
let claimed = IngestionTask::claim_next_ready(&db, worker_id, now, Duration::from_secs(60))
|
||||
.await
|
||||
.expect("claim")
|
||||
.expect("claimed");
|
||||
|
||||
let processing = claimed.mark_processing(&db).await.expect("processing");
|
||||
|
||||
let error_info = TaskErrorInfo {
|
||||
code: Some("pipeline_error".into()),
|
||||
message: "failed".into(),
|
||||
};
|
||||
|
||||
IngestionTask::update_status(&task_id, new_status.clone(), &db)
|
||||
let failed = processing
|
||||
.mark_failed(error_info.clone(), Duration::from_secs(30), &db)
|
||||
.await
|
||||
.expect("Failed to update status");
|
||||
.expect("failed update");
|
||||
assert_eq!(failed.state, TaskState::Failed);
|
||||
assert_eq!(failed.error_message.as_deref(), Some("failed"));
|
||||
assert!(failed.worker_id.is_none());
|
||||
assert!(failed.locked_at.is_none());
|
||||
assert!(failed.scheduled_at > now);
|
||||
|
||||
// Verify status updated
|
||||
let updated_task: Option<IngestionTask> = db
|
||||
.get_item::<IngestionTask>(&task_id)
|
||||
let dead = failed
|
||||
.mark_dead_letter(error_info.clone(), &db)
|
||||
.await
|
||||
.expect("Failed to get updated task");
|
||||
.expect("dead letter");
|
||||
assert_eq!(dead.state, TaskState::DeadLetter);
|
||||
assert_eq!(dead.error_message.as_deref(), Some("failed"));
|
||||
}
|
||||
|
||||
assert!(updated_task.is_some());
|
||||
let updated_task = updated_task.unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_mark_processing_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
assert_eq!(attempts, 1);
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.mark_processing(&db)
|
||||
.await
|
||||
.expect_err("processing should fail without reservation");
|
||||
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> start_processing"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected InProgress status"),
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_unfinished_tasks() {
|
||||
// Setup in-memory database
|
||||
let namespace = "test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("Failed to start in-memory surrealdb");
|
||||
|
||||
async fn test_mark_failed_requires_processing() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_test_payload(user_id);
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
// Create tasks with different statuses
|
||||
let created_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let mut in_progress_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut max_attempts_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
max_attempts_task.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: Utc::now(),
|
||||
};
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
|
||||
let mut error_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
error_task.status = IngestionTaskStatus::Error {
|
||||
message: "Test error".to_string(),
|
||||
};
|
||||
|
||||
// Store all tasks
|
||||
db.store_item(created_task)
|
||||
let err = task
|
||||
.mark_failed(
|
||||
TaskErrorInfo {
|
||||
code: None,
|
||||
message: "boom".into(),
|
||||
},
|
||||
Duration::from_secs(30),
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
db.store_item(in_progress_task)
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
db.store_item(max_attempts_task)
|
||||
.await
|
||||
.expect("Failed to store max-attempts task");
|
||||
db.store_item(completed_task)
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
db.store_item(error_task)
|
||||
.await
|
||||
.expect("Failed to store error task");
|
||||
.expect_err("failing should require processing state");
|
||||
|
||||
// Get unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db)
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> fail"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_release_requires_reservation() {
|
||||
let db = memory_db().await;
|
||||
let user_id = "user123";
|
||||
let payload = create_payload(user_id);
|
||||
|
||||
let task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
db.store_item(task.clone()).await.expect("store");
|
||||
|
||||
let err = task
|
||||
.release(&db)
|
||||
.await
|
||||
.expect("Failed to get unfinished tasks");
|
||||
.expect_err("release should require reserved state");
|
||||
|
||||
// Verify only Created and InProgress with attempts < MAX_ATTEMPTS are returned
|
||||
assert_eq!(unfinished_tasks.len(), 2);
|
||||
|
||||
let statuses: Vec<_> = unfinished_tasks
|
||||
.iter()
|
||||
.map(|task| match &task.status {
|
||||
IngestionTaskStatus::Created => "Created",
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
if *attempts < MAX_ATTEMPTS {
|
||||
"InProgress<MAX"
|
||||
} else {
|
||||
"InProgress>=MAX"
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed",
|
||||
IngestionTaskStatus::Error { .. } => "Error",
|
||||
IngestionTaskStatus::Cancelled => "Cancelled",
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert!(statuses.contains(&"Created"));
|
||||
assert!(statuses.contains(&"InProgress<MAX"));
|
||||
assert!(!statuses.contains(&"InProgress>=MAX"));
|
||||
assert!(!statuses.contains(&"Completed"));
|
||||
assert!(!statuses.contains(&"Error"));
|
||||
assert!(!statuses.contains(&"Cancelled"));
|
||||
match err {
|
||||
AppError::Validation(message) => {
|
||||
assert!(
|
||||
message.contains("Pending -> release"),
|
||||
"unexpected message: {message}"
|
||||
);
|
||||
}
|
||||
other => panic!("expected validation error, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +150,18 @@ impl KnowledgeEntity {
|
||||
let all_entities: Vec<KnowledgeEntity> = db.select(Self::table_name()).await?;
|
||||
let total_entities = all_entities.len();
|
||||
if total_entities == 0 {
|
||||
info!("No knowledge entities to update. Skipping.");
|
||||
info!("No knowledge entities to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query
|
||||
.push_str("REMOVE INDEX idx_embedding_entities ON TABLE knowledge_entity;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions
|
||||
));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} entities to process.", total_entities);
|
||||
|
||||
@@ -83,6 +83,32 @@ macro_rules! stored_object {
|
||||
Ok(DateTime::<Utc>::from(dt))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn serialize_option_datetime<S>(
|
||||
date: &Option<DateTime<Utc>>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match date {
|
||||
Some(dt) => serializer
|
||||
.serialize_some(&Into::<surrealdb::sql::Datetime>::into(*dt)),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn deserialize_option_datetime<'de, D>(
|
||||
deserializer: D,
|
||||
) -> Result<Option<DateTime<Utc>>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let value = Option::<surrealdb::sql::Datetime>::deserialize(deserializer)?;
|
||||
Ok(value.map(DateTime::<Utc>::from))
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct $name {
|
||||
@@ -92,7 +118,7 @@ macro_rules! stored_object {
|
||||
pub created_at: DateTime<Utc>,
|
||||
#[serde(serialize_with = "serialize_datetime", deserialize_with = "deserialize_datetime", default)]
|
||||
pub updated_at: DateTime<Utc>,
|
||||
$(pub $field: $ty),*
|
||||
$( $(#[$attr])* pub $field: $ty),*
|
||||
}
|
||||
|
||||
impl StoredObject for $name {
|
||||
|
||||
@@ -53,11 +53,60 @@ impl SystemSettings {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::storage::types::text_chunk::TextChunk;
|
||||
use crate::storage::types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk};
|
||||
use async_openai::Client;
|
||||
|
||||
use super::*;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn get_hnsw_index_dimension(
|
||||
db: &SurrealDbClient,
|
||||
table_name: &str,
|
||||
index_name: &str,
|
||||
) -> u32 {
|
||||
let query = format!("INFO FOR TABLE {table_name};");
|
||||
let mut response = db
|
||||
.client
|
||||
.query(query)
|
||||
.await
|
||||
.expect("Failed to fetch table info");
|
||||
|
||||
let info: Option<serde_json::Value> = response
|
||||
.take(0)
|
||||
.expect("Failed to extract table info response");
|
||||
|
||||
let info = info.expect("Table info result missing");
|
||||
|
||||
let indexes = info
|
||||
.get("indexes")
|
||||
.or_else(|| {
|
||||
info.get("tables")
|
||||
.and_then(|tables| tables.get(table_name))
|
||||
.and_then(|table| table.get("indexes"))
|
||||
})
|
||||
.unwrap_or_else(|| panic!("Indexes collection missing in table info: {info:#?}"));
|
||||
|
||||
let definition = indexes
|
||||
.get(index_name)
|
||||
.and_then(|definition| definition.as_str())
|
||||
.unwrap_or_else(|| panic!("Index definition not found in table info: {info:#?}"));
|
||||
|
||||
let dimension_part = definition
|
||||
.split("DIMENSION")
|
||||
.nth(1)
|
||||
.expect("Index definition missing DIMENSION clause");
|
||||
|
||||
let dimension_token = dimension_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.expect("Dimension value missing in definition")
|
||||
.trim_end_matches(';');
|
||||
|
||||
dimension_token
|
||||
.parse::<u32>()
|
||||
.expect("Dimension value is not a valid number")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_settings_initialization() {
|
||||
// Setup in-memory database for testing
|
||||
@@ -255,4 +304,74 @@ mod tests {
|
||||
|
||||
assert!(migration_result.is_ok(), "Migrations should not fail");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_should_change_embedding_length_on_indexes_when_switching_length() {
|
||||
let db = SurrealDbClient::memory("test", &Uuid::new_v4().to_string())
|
||||
.await
|
||||
.expect("Failed to start DB");
|
||||
|
||||
// Apply initial migrations. This sets up the text_chunk index with DIMENSION 1536.
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("Initial migration failed");
|
||||
|
||||
let mut current_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to load current settings");
|
||||
|
||||
let initial_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
|
||||
assert_eq!(
|
||||
initial_chunk_dimension, current_settings.embedding_dimensions,
|
||||
"embedding size should match initial system settings"
|
||||
);
|
||||
|
||||
let new_dimension = 768;
|
||||
let new_model = "new-test-embedding-model".to_string();
|
||||
|
||||
current_settings.embedding_dimensions = new_dimension;
|
||||
current_settings.embedding_model = new_model.clone();
|
||||
|
||||
let updated_settings = SystemSettings::update(&db, current_settings)
|
||||
.await
|
||||
.expect("Failed to update settings");
|
||||
|
||||
assert_eq!(
|
||||
updated_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should reflect the new embedding dimension"
|
||||
);
|
||||
|
||||
let openai_client = Client::new();
|
||||
|
||||
TextChunk::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("TextChunk re-embedding should succeed on fresh DB");
|
||||
KnowledgeEntity::update_all_embeddings(&db, &openai_client, &new_model, new_dimension)
|
||||
.await
|
||||
.expect("KnowledgeEntity re-embedding should succeed on fresh DB");
|
||||
|
||||
let text_chunk_dimension =
|
||||
get_hnsw_index_dimension(&db, "text_chunk", "idx_embedding_chunks").await;
|
||||
let knowledge_dimension =
|
||||
get_hnsw_index_dimension(&db, "knowledge_entity", "idx_embedding_entities").await;
|
||||
|
||||
assert_eq!(
|
||||
text_chunk_dimension, new_dimension,
|
||||
"text_chunk index dimension should update"
|
||||
);
|
||||
assert_eq!(
|
||||
knowledge_dimension, new_dimension,
|
||||
"knowledge_entity index dimension should update"
|
||||
);
|
||||
|
||||
let persisted_settings = SystemSettings::get_current(&db)
|
||||
.await
|
||||
.expect("Failed to reload updated settings");
|
||||
assert_eq!(
|
||||
persisted_settings.embedding_dimensions, new_dimension,
|
||||
"Settings should persist new embedding dimension"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,7 +68,17 @@ impl TextChunk {
|
||||
let all_chunks: Vec<TextChunk> = db.select(Self::table_name()).await?;
|
||||
let total_chunks = all_chunks.len();
|
||||
if total_chunks == 0 {
|
||||
info!("No text chunks to update. Skipping.");
|
||||
info!("No text chunks to update. Just updating the idx");
|
||||
|
||||
let mut transaction_query = String::from("BEGIN TRANSACTION;");
|
||||
transaction_query.push_str("REMOVE INDEX idx_embedding_chunks ON TABLE text_chunk;");
|
||||
transaction_query.push_str(&format!(
|
||||
"DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION {};",
|
||||
new_dimensions));
|
||||
transaction_query.push_str("COMMIT TRANSACTION;");
|
||||
|
||||
db.query(transaction_query).await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
info!("Found {} chunks to process.", total_chunks);
|
||||
|
||||
@@ -8,7 +8,7 @@ use uuid::Uuid;
|
||||
use super::text_chunk::TextChunk;
|
||||
use super::{
|
||||
conversation::Conversation,
|
||||
ingestion_task::{IngestionTask, MAX_ATTEMPTS},
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
system_settings::SystemSettings,
|
||||
@@ -535,19 +535,43 @@ impl User {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
status.name = 'Created'
|
||||
OR (
|
||||
status.name = 'InProgress'
|
||||
AND status.attempts < $max_attempts
|
||||
)
|
||||
)
|
||||
ORDER BY created_at DESC",
|
||||
WHERE user_id = $user_id
|
||||
AND (
|
||||
state IN $active_states
|
||||
OR (state = $failed_state AND attempts < max_attempts)
|
||||
)
|
||||
ORDER BY scheduled_at ASC, created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind((
|
||||
"active_states",
|
||||
vec![
|
||||
TaskState::Pending.as_str(),
|
||||
TaskState::Reserved.as_str(),
|
||||
TaskState::Processing.as_str(),
|
||||
],
|
||||
))
|
||||
.bind(("failed_state", TaskState::Failed.as_str()))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
/// Gets all ingestion tasks for the specified user ordered by newest first
|
||||
pub async fn get_all_ingestion_tasks(
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<IngestionTask>, AppError> {
|
||||
let jobs: Vec<IngestionTask> = db
|
||||
.query(
|
||||
"SELECT * FROM type::table($table)
|
||||
WHERE user_id = $user_id
|
||||
ORDER BY created_at DESC",
|
||||
)
|
||||
.bind(("table", IngestionTask::table_name()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("max_attempts", MAX_ATTEMPTS))
|
||||
.await?
|
||||
.take(0)?;
|
||||
|
||||
@@ -605,7 +629,7 @@ impl User {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::storage::types::ingestion_payload::IngestionPayload;
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, IngestionTaskStatus, MAX_ATTEMPTS};
|
||||
use crate::storage::types::ingestion_task::{IngestionTask, TaskState, MAX_ATTEMPTS};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// Helper function to set up a test database with SystemSettings
|
||||
@@ -710,28 +734,32 @@ mod tests {
|
||||
.await
|
||||
.expect("Failed to store created task");
|
||||
|
||||
let mut in_progress_allowed =
|
||||
IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_allowed.status = IngestionTaskStatus::InProgress {
|
||||
attempts: 1,
|
||||
last_attempt: chrono::Utc::now(),
|
||||
};
|
||||
db.store_item(in_progress_allowed.clone())
|
||||
let mut processing_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
processing_task.state = TaskState::Processing;
|
||||
processing_task.attempts = 1;
|
||||
db.store_item(processing_task.clone())
|
||||
.await
|
||||
.expect("Failed to store in-progress task");
|
||||
.expect("Failed to store processing task");
|
||||
|
||||
let mut in_progress_blocked =
|
||||
let mut failed_retry_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
failed_retry_task.state = TaskState::Failed;
|
||||
failed_retry_task.attempts = 1;
|
||||
failed_retry_task.scheduled_at = chrono::Utc::now() - chrono::Duration::minutes(5);
|
||||
db.store_item(failed_retry_task.clone())
|
||||
.await
|
||||
.expect("Failed to store retryable failed task");
|
||||
|
||||
let mut failed_blocked_task =
|
||||
IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
in_progress_blocked.status = IngestionTaskStatus::InProgress {
|
||||
attempts: MAX_ATTEMPTS,
|
||||
last_attempt: chrono::Utc::now(),
|
||||
};
|
||||
db.store_item(in_progress_blocked.clone())
|
||||
failed_blocked_task.state = TaskState::Failed;
|
||||
failed_blocked_task.attempts = MAX_ATTEMPTS;
|
||||
failed_blocked_task.error_message = Some("Too many failures".into());
|
||||
db.store_item(failed_blocked_task.clone())
|
||||
.await
|
||||
.expect("Failed to store blocked task");
|
||||
|
||||
let mut completed_task = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
completed_task.status = IngestionTaskStatus::Completed;
|
||||
completed_task.state = TaskState::Succeeded;
|
||||
db.store_item(completed_task.clone())
|
||||
.await
|
||||
.expect("Failed to store completed task");
|
||||
@@ -755,10 +783,54 @@ mod tests {
|
||||
unfinished.iter().map(|task| task.id.clone()).collect();
|
||||
|
||||
assert!(unfinished_ids.contains(&created_task.id));
|
||||
assert!(unfinished_ids.contains(&in_progress_allowed.id));
|
||||
assert!(!unfinished_ids.contains(&in_progress_blocked.id));
|
||||
assert!(unfinished_ids.contains(&processing_task.id));
|
||||
assert!(unfinished_ids.contains(&failed_retry_task.id));
|
||||
assert!(!unfinished_ids.contains(&failed_blocked_task.id));
|
||||
assert!(!unfinished_ids.contains(&completed_task.id));
|
||||
assert_eq!(unfinished_ids.len(), 2);
|
||||
assert_eq!(unfinished_ids.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_all_ingestion_tasks_returns_sorted() {
|
||||
let db = setup_test_db().await;
|
||||
let user_id = "archive_user";
|
||||
let other_user_id = "other_user";
|
||||
|
||||
let payload = IngestionPayload::Text {
|
||||
text: "One".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: user_id.to_string(),
|
||||
};
|
||||
|
||||
// Oldest task
|
||||
let mut first = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
first.created_at = first.created_at - chrono::Duration::minutes(1);
|
||||
first.updated_at = first.created_at;
|
||||
first.state = TaskState::Succeeded;
|
||||
db.store_item(first.clone()).await.expect("store first");
|
||||
|
||||
// Latest task
|
||||
let mut second = IngestionTask::new(payload.clone(), user_id.to_string()).await;
|
||||
second.state = TaskState::Processing;
|
||||
db.store_item(second.clone()).await.expect("store second");
|
||||
|
||||
let other_payload = IngestionPayload::Text {
|
||||
text: "Other".to_string(),
|
||||
context: "Context".to_string(),
|
||||
category: "Category".to_string(),
|
||||
user_id: other_user_id.to_string(),
|
||||
};
|
||||
let other_task = IngestionTask::new(other_payload, other_user_id.to_string()).await;
|
||||
db.store_item(other_task).await.expect("store other");
|
||||
|
||||
let tasks = User::get_all_ingestion_tasks(user_id, &db)
|
||||
.await
|
||||
.expect("fetch all tasks");
|
||||
|
||||
assert_eq!(tasks.len(), 2);
|
||||
assert_eq!(tasks[0].id, second.id); // newest first
|
||||
assert_eq!(tasks[1].id, first.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -11,7 +11,6 @@ use common::{
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
message::{format_history, Message},
|
||||
system_settings::SystemSettings,
|
||||
},
|
||||
@@ -20,7 +19,7 @@ use common::{
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::retrieve_entities;
|
||||
use crate::{retrieve_entities, RetrievedEntity};
|
||||
|
||||
use super::answer_retrieval_helper::get_query_response_schema;
|
||||
|
||||
@@ -84,21 +83,31 @@ pub async fn get_answer_with_references(
|
||||
})
|
||||
}
|
||||
|
||||
pub fn format_entities_json(entities: &[KnowledgeEntity]) -> Value {
|
||||
pub fn format_entities_json(entities: &[RetrievedEntity]) -> Value {
|
||||
json!(entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
.map(|entry| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
"id": entry.entity.id,
|
||||
"name": entry.entity.name,
|
||||
"description": entry.entity.description,
|
||||
"score": round_score(entry.score),
|
||||
"chunks": entry.chunks.iter().map(|chunk| {
|
||||
json!({
|
||||
"score": round_score(chunk.score),
|
||||
"content": chunk.chunk.chunk
|
||||
})
|
||||
}).collect::<Vec<_>>()
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
fn round_score(value: f32) -> f64 {
|
||||
((value as f64) * 1000.0).round() / 1000.0
|
||||
}
|
||||
pub fn create_user_message(entities_json: &Value, query: &str) -> String {
|
||||
format!(
|
||||
r#"
|
||||
|
||||
265
composite-retrieval/src/fts.rs
Normal file
265
composite-retrieval/src/fts.rs
Normal file
@@ -0,0 +1,265 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
};
|
||||
|
||||
use crate::scoring::Scored;
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FtsScoreRow {
|
||||
#[serde(deserialize_with = "deserialize_flexible_id")]
|
||||
id: String,
|
||||
fts_score: Option<f32>,
|
||||
}
|
||||
|
||||
/// Executes a full-text search query against SurrealDB and returns scored results.
|
||||
///
|
||||
/// The function expects FTS indexes to exist for the provided table. Currently supports
|
||||
/// `knowledge_entity` (name + description) and `text_chunk` (chunk).
|
||||
pub async fn find_items_by_fts<T>(
|
||||
take: usize,
|
||||
query: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
let (filter_clause, score_clause) = match table {
|
||||
"knowledge_entity" => (
|
||||
"(name @0@ $terms OR description @1@ $terms)",
|
||||
"(IF search::score(0) != NONE THEN search::score(0) ELSE 0 END) + \
|
||||
(IF search::score(1) != NONE THEN search::score(1) ELSE 0 END)",
|
||||
),
|
||||
"text_chunk" => (
|
||||
"(chunk @0@ $terms)",
|
||||
"IF search::score(0) != NONE THEN search::score(0) ELSE 0 END",
|
||||
),
|
||||
_ => {
|
||||
return Err(AppError::Validation(format!(
|
||||
"FTS not configured for table '{table}'"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let sql = format!(
|
||||
"SELECT id, {score_clause} AS fts_score \
|
||||
FROM {table} \
|
||||
WHERE {filter_clause} \
|
||||
AND user_id = $user_id \
|
||||
ORDER BY fts_score DESC \
|
||||
LIMIT $limit",
|
||||
table = table,
|
||||
filter_clause = filter_clause,
|
||||
score_clause = score_clause
|
||||
);
|
||||
|
||||
debug!(
|
||||
table = table,
|
||||
limit = take,
|
||||
"Executing FTS query with filter clause: {}",
|
||||
filter_clause
|
||||
);
|
||||
|
||||
let mut response = db_client
|
||||
.query(sql)
|
||||
.bind(("terms", query.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.bind(("limit", take as i64))
|
||||
.await?;
|
||||
|
||||
let score_rows: Vec<FtsScoreRow> = response.take(0)?;
|
||||
|
||||
if score_rows.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids: Vec<String> = score_rows.iter().map(|row| row.id.clone()).collect();
|
||||
let thing_ids: Vec<Thing> = ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((table, id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut items_response = db_client
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", table.to_owned()))
|
||||
.bind(("things", thing_ids.clone()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let items: Vec<T> = items_response.take(0)?;
|
||||
|
||||
let mut item_map: HashMap<String, T> = items
|
||||
.into_iter()
|
||||
.map(|item| (item.get_id().to_owned(), item))
|
||||
.collect();
|
||||
|
||||
let mut results = Vec::with_capacity(score_rows.len());
|
||||
for row in score_rows {
|
||||
if let Some(item) = item_map.remove(&row.id) {
|
||||
let score = row.fts_score.unwrap_or_default();
|
||||
results.push(Scored::new(item).with_fts_score(score));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::storage::types::{
|
||||
knowledge_entity::{KnowledgeEntity, KnowledgeEntityType},
|
||||
text_chunk::TextChunk,
|
||||
StoredObject,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
fn dummy_embedding() -> Vec<f32> {
|
||||
vec![0.0; 1536]
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_single_field_score_for_name() {
|
||||
let namespace = "fts_test_ns";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_a".into(),
|
||||
"Rustacean handbook".into(),
|
||||
"completely unrelated description".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to insert entity");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results = find_items_by_fts::<KnowledgeEntity>(
|
||||
5,
|
||||
"rustacean",
|
||||
&db,
|
||||
KnowledgeEntity::table_name(),
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when only the name matched"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_single_field_score_for_description() {
|
||||
let namespace = "fts_test_ns_desc";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts_desc";
|
||||
let entity = KnowledgeEntity::new(
|
||||
"source_b".into(),
|
||||
"neutral name".into(),
|
||||
"Detailed notes about async runtimes".into(),
|
||||
KnowledgeEntityType::Document,
|
||||
None,
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(entity.clone())
|
||||
.await
|
||||
.expect("failed to insert entity");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results = find_items_by_fts::<KnowledgeEntity>(
|
||||
5,
|
||||
"async",
|
||||
&db,
|
||||
KnowledgeEntity::table_name(),
|
||||
user_id,
|
||||
)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when only the description matched"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fts_preserves_scores_for_text_chunks() {
|
||||
let namespace = "fts_test_ns_chunks";
|
||||
let database = &Uuid::new_v4().to_string();
|
||||
let db = SurrealDbClient::memory(namespace, database)
|
||||
.await
|
||||
.expect("failed to create in-memory surreal");
|
||||
|
||||
db.apply_migrations()
|
||||
.await
|
||||
.expect("failed to apply migrations");
|
||||
|
||||
let user_id = "user_fts_chunk";
|
||||
let chunk = TextChunk::new(
|
||||
"source_chunk".into(),
|
||||
"GraphQL documentation reference".into(),
|
||||
dummy_embedding(),
|
||||
user_id.into(),
|
||||
);
|
||||
|
||||
db.store_item(chunk.clone())
|
||||
.await
|
||||
.expect("failed to insert chunk");
|
||||
|
||||
db.rebuild_indexes()
|
||||
.await
|
||||
.expect("failed to rebuild indexes");
|
||||
|
||||
let results =
|
||||
find_items_by_fts::<TextChunk>(5, "graphql", &db, TextChunk::table_name(), user_id)
|
||||
.await
|
||||
.expect("fts query failed");
|
||||
|
||||
assert!(!results.is_empty(), "expected at least one FTS result");
|
||||
assert!(
|
||||
results[0].scores.fts.is_some(),
|
||||
"expected an FTS score when chunk field matched"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,14 @@
|
||||
use surrealdb::Error;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEntity};
|
||||
use surrealdb::{sql::Thing, Error};
|
||||
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
knowledge_entity::KnowledgeEntity, knowledge_relationship::KnowledgeRelationship,
|
||||
StoredObject,
|
||||
},
|
||||
};
|
||||
|
||||
/// Retrieves database entries that match a specific source identifier.
|
||||
///
|
||||
@@ -30,18 +38,21 @@ use common::storage::{db::SurrealDbClient, types::knowledge_entity::KnowledgeEnt
|
||||
/// * The database query fails to execute
|
||||
/// * The results cannot be deserialized into type `T`
|
||||
pub async fn find_entities_by_source_ids<T>(
|
||||
source_id: Vec<String>,
|
||||
table_name: String,
|
||||
source_ids: Vec<String>,
|
||||
table_name: &str,
|
||||
user_id: &str,
|
||||
db: &SurrealDbClient,
|
||||
) -> Result<Vec<T>, Error>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
{
|
||||
let query = "SELECT * FROM type::table($table) WHERE source_id IN $source_ids";
|
||||
let query =
|
||||
"SELECT * FROM type::table($table) WHERE source_id IN $source_ids AND user_id = $user_id";
|
||||
|
||||
db.query(query)
|
||||
.bind(("table", table_name))
|
||||
.bind(("source_ids", source_id))
|
||||
.bind(("table", table_name.to_owned()))
|
||||
.bind(("source_ids", source_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?
|
||||
.take(0)
|
||||
}
|
||||
@@ -49,14 +60,92 @@ where
|
||||
/// Find entities by their relationship to the id
|
||||
pub async fn find_entities_by_relationship_by_id(
|
||||
db: &SurrealDbClient,
|
||||
entity_id: String,
|
||||
entity_id: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<KnowledgeEntity>, Error> {
|
||||
let query = format!(
|
||||
"SELECT *, <-> relates_to <-> knowledge_entity AS related FROM knowledge_entity:`{}`",
|
||||
entity_id
|
||||
);
|
||||
let mut relationships_response = db
|
||||
.query(
|
||||
"
|
||||
SELECT * FROM relates_to
|
||||
WHERE metadata.user_id = $user_id
|
||||
AND (in = type::thing('knowledge_entity', $entity_id)
|
||||
OR out = type::thing('knowledge_entity', $entity_id))
|
||||
",
|
||||
)
|
||||
.bind(("entity_id", entity_id.to_owned()))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
db.query(query).await?.take(0)
|
||||
let relationships: Vec<KnowledgeRelationship> = relationships_response.take(0)?;
|
||||
if relationships.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_ids: Vec<String> = Vec::new();
|
||||
let mut seen: HashSet<String> = HashSet::new();
|
||||
for rel in relationships {
|
||||
if rel.in_ == entity_id {
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
} else if rel.out == entity_id {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_);
|
||||
}
|
||||
} else {
|
||||
if seen.insert(rel.in_.clone()) {
|
||||
neighbor_ids.push(rel.in_.clone());
|
||||
}
|
||||
if seen.insert(rel.out.clone()) {
|
||||
neighbor_ids.push(rel.out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
neighbor_ids.retain(|id| id != entity_id);
|
||||
|
||||
if neighbor_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
if limit > 0 && neighbor_ids.len() > limit {
|
||||
neighbor_ids.truncate(limit);
|
||||
}
|
||||
|
||||
let thing_ids: Vec<Thing> = neighbor_ids
|
||||
.iter()
|
||||
.map(|id| Thing::from((KnowledgeEntity::table_name(), id.as_str())))
|
||||
.collect();
|
||||
|
||||
let mut neighbors_response = db
|
||||
.query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
|
||||
.bind(("table", KnowledgeEntity::table_name().to_owned()))
|
||||
.bind(("things", thing_ids))
|
||||
.bind(("user_id", user_id.to_owned()))
|
||||
.await?;
|
||||
|
||||
let neighbors: Vec<KnowledgeEntity> = neighbors_response.take(0)?;
|
||||
if neighbors.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut neighbor_map: HashMap<String, KnowledgeEntity> = neighbors
|
||||
.into_iter()
|
||||
.map(|entity| (entity.id.clone(), entity))
|
||||
.collect();
|
||||
|
||||
let mut ordered = Vec::new();
|
||||
for id in neighbor_ids {
|
||||
if let Some(entity) = neighbor_map.remove(&id) {
|
||||
ordered.push(entity);
|
||||
}
|
||||
if limit > 0 && ordered.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ordered)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -146,7 +235,7 @@ mod tests {
|
||||
// Test finding entities by multiple source_ids
|
||||
let source_ids = vec![source_id1.clone(), source_id2.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name().to_string(), &db)
|
||||
find_entities_by_source_ids(source_ids, KnowledgeEntity::table_name(), &user_id, &db)
|
||||
.await
|
||||
.expect("Failed to find entities by source_ids");
|
||||
|
||||
@@ -177,7 +266,8 @@ mod tests {
|
||||
let single_source_id = vec![source_id1.clone()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
single_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
KnowledgeEntity::table_name(),
|
||||
&user_id,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -202,7 +292,8 @@ mod tests {
|
||||
let non_existent_source_id = vec!["non_existent_source".to_string()];
|
||||
let found_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
non_existent_source_id,
|
||||
KnowledgeEntity::table_name().to_string(),
|
||||
KnowledgeEntity::table_name(),
|
||||
&user_id,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
@@ -327,11 +418,15 @@ mod tests {
|
||||
.expect("Failed to store relationship 2");
|
||||
|
||||
// Test finding entities related to the central entity
|
||||
let related_entities = find_entities_by_relationship_by_id(&db, central_entity.id.clone())
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
let related_entities =
|
||||
find_entities_by_relationship_by_id(&db, ¢ral_entity.id, &user_id, usize::MAX)
|
||||
.await
|
||||
.expect("Failed to find entities by relationship");
|
||||
|
||||
// Check that we found relationships
|
||||
assert!(related_entities.len() > 0, "Should find related entities");
|
||||
assert!(
|
||||
related_entities.len() >= 2,
|
||||
"Should find related entities in both directions"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,90 +1,721 @@
|
||||
pub mod answer_retrieval;
|
||||
pub mod answer_retrieval_helper;
|
||||
pub mod fts;
|
||||
pub mod graph;
|
||||
pub mod scoring;
|
||||
pub mod vector;
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk},
|
||||
types::{knowledge_entity::KnowledgeEntity, text_chunk::TextChunk, StoredObject},
|
||||
},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use futures::future::{try_join, try_join_all};
|
||||
use futures::{stream::FuturesUnordered, StreamExt};
|
||||
use graph::{find_entities_by_relationship_by_id, find_entities_by_source_ids};
|
||||
use std::collections::HashMap;
|
||||
use vector::find_items_by_vector_similarity;
|
||||
use scoring::{
|
||||
clamp_unit, fuse_scores, merge_scored_by_id, min_max_normalize, sort_by_fused_desc,
|
||||
FusionWeights, Scored,
|
||||
};
|
||||
use tracing::{debug, instrument, trace};
|
||||
|
||||
/// Performs a comprehensive knowledge entity retrieval using multiple search strategies
|
||||
/// to find the most relevant entities for a given query.
|
||||
///
|
||||
/// # Strategy
|
||||
/// The function employs a three-pronged approach to knowledge retrieval:
|
||||
/// 1. Direct vector similarity search on knowledge entities
|
||||
/// 2. Text chunk similarity search with source entity lookup
|
||||
/// 3. Graph relationship traversal from related entities
|
||||
///
|
||||
/// This combined approach ensures both semantic similarity matches and structurally
|
||||
/// related content are included in the results.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `db_client` - SurrealDB client for database operations
|
||||
/// * `openai_client` - OpenAI client for vector embeddings generation
|
||||
/// * `query` - The search query string to find relevant knowledge entities
|
||||
/// * 'user_id' - The user id of the current user
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Vec<KnowledgeEntity>, AppError>` - A deduplicated vector of relevant
|
||||
/// knowledge entities, or an error if the retrieval process fails
|
||||
use crate::{fts::find_items_by_fts, vector::find_items_by_vector_similarity_with_embedding};
|
||||
|
||||
// Tunable knobs controlling first-pass recall, graph expansion, and answer shaping.
|
||||
const ENTITY_VECTOR_TAKE: usize = 15;
|
||||
const CHUNK_VECTOR_TAKE: usize = 20;
|
||||
const ENTITY_FTS_TAKE: usize = 10;
|
||||
const CHUNK_FTS_TAKE: usize = 20;
|
||||
const SCORE_THRESHOLD: f32 = 0.35;
|
||||
const FALLBACK_MIN_RESULTS: usize = 10;
|
||||
const TOKEN_BUDGET_ESTIMATE: usize = 2800;
|
||||
const AVG_CHARS_PER_TOKEN: usize = 4;
|
||||
const MAX_CHUNKS_PER_ENTITY: usize = 4;
|
||||
const GRAPH_TRAVERSAL_SEED_LIMIT: usize = 5;
|
||||
const GRAPH_NEIGHBOR_LIMIT: usize = 6;
|
||||
const GRAPH_SCORE_DECAY: f32 = 0.75;
|
||||
const GRAPH_SEED_MIN_SCORE: f32 = 0.4;
|
||||
const GRAPH_VECTOR_INHERITANCE: f32 = 0.6;
|
||||
|
||||
// Captures a supporting chunk plus its fused retrieval score for downstream prompts.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedChunk {
|
||||
pub chunk: TextChunk,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
// Final entity representation returned to callers, enriched with ranked chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetrievedEntity {
|
||||
pub entity: KnowledgeEntity,
|
||||
pub score: f32,
|
||||
pub chunks: Vec<RetrievedChunk>,
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(user_id))]
|
||||
pub async fn retrieve_entities(
|
||||
db_client: &SurrealDbClient,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
let (items_from_knowledge_entity_similarity, closest_chunks) = try_join(
|
||||
find_items_by_vector_similarity(
|
||||
10,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
trace!("Generating query embedding for hybrid retrieval");
|
||||
let query_embedding = generate_embedding(openai_client, query, db_client).await?;
|
||||
retrieve_entities_with_embedding(db_client, query_embedding, query, user_id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn retrieve_entities_with_embedding(
|
||||
db_client: &SurrealDbClient,
|
||||
query_embedding: Vec<f32>,
|
||||
query: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
// 1) Gather first-pass candidates from vector search and BM25.
|
||||
let weights = FusionWeights::default();
|
||||
|
||||
let (vector_entities, vector_chunks, mut fts_entities, mut fts_chunks) = tokio::try_join!(
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
ENTITY_VECTOR_TAKE,
|
||||
query_embedding.clone(),
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
),
|
||||
find_items_by_vector_similarity_with_embedding(
|
||||
CHUNK_VECTOR_TAKE,
|
||||
query_embedding,
|
||||
db_client,
|
||||
"text_chunk",
|
||||
user_id,
|
||||
),
|
||||
find_items_by_fts(
|
||||
ENTITY_FTS_TAKE,
|
||||
query,
|
||||
db_client,
|
||||
"knowledge_entity",
|
||||
openai_client,
|
||||
user_id,
|
||||
user_id
|
||||
),
|
||||
find_items_by_vector_similarity(5, query, db_client, "text_chunk", openai_client, user_id),
|
||||
find_items_by_fts(CHUNK_FTS_TAKE, query, db_client, "text_chunk", user_id),
|
||||
)?;
|
||||
|
||||
debug!(
|
||||
vector_entities = vector_entities.len(),
|
||||
vector_chunks = vector_chunks.len(),
|
||||
fts_entities = fts_entities.len(),
|
||||
fts_chunks = fts_chunks.len(),
|
||||
"Hybrid retrieval initial candidate counts"
|
||||
);
|
||||
|
||||
normalize_fts_scores(&mut fts_entities);
|
||||
normalize_fts_scores(&mut fts_chunks);
|
||||
|
||||
let mut entity_candidates: HashMap<String, Scored<KnowledgeEntity>> = HashMap::new();
|
||||
let mut chunk_candidates: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
|
||||
// Collate raw retrieval results so each ID accumulates all available signals.
|
||||
merge_scored_by_id(&mut entity_candidates, vector_entities);
|
||||
merge_scored_by_id(&mut entity_candidates, fts_entities);
|
||||
merge_scored_by_id(&mut chunk_candidates, vector_chunks);
|
||||
merge_scored_by_id(&mut chunk_candidates, fts_chunks);
|
||||
|
||||
// 2) Normalize scores, fuse them, and allow high-confidence entities to pull neighbors from the graph.
|
||||
apply_fusion(&mut entity_candidates, weights);
|
||||
apply_fusion(&mut chunk_candidates, weights);
|
||||
enrich_entities_from_graph(&mut entity_candidates, db_client, user_id, weights).await?;
|
||||
|
||||
// 3) Track high-signal chunk sources so we can backfill missing entities.
|
||||
let chunk_by_source = group_chunks_by_source(&chunk_candidates);
|
||||
let mut missing_sources = Vec::new();
|
||||
|
||||
for source_id in chunk_by_source.keys() {
|
||||
if !entity_candidates
|
||||
.values()
|
||||
.any(|entity| entity.item.source_id == *source_id)
|
||||
{
|
||||
missing_sources.push(source_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if !missing_sources.is_empty() {
|
||||
let related_entities: Vec<KnowledgeEntity> = find_entities_by_source_ids(
|
||||
missing_sources.clone(),
|
||||
"knowledge_entity",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
for entity in related_entities {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
let mut scored = Scored::new(entity.clone()).with_vector_score(best_chunk_score);
|
||||
let fused = fuse_scores(&scored.scores, weights);
|
||||
scored.update_fused(fused);
|
||||
entity_candidates.insert(entity.id.clone(), scored);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Boost entities with evidence from high scoring chunks.
|
||||
for entity in entity_candidates.values_mut() {
|
||||
if let Some(chunks) = chunk_by_source.get(&entity.item.source_id) {
|
||||
let best_chunk_score = chunks
|
||||
.iter()
|
||||
.map(|chunk| chunk.fused)
|
||||
.fold(0.0f32, f32::max);
|
||||
|
||||
if best_chunk_score > 0.0 {
|
||||
let boosted = entity.scores.vector.unwrap_or(0.0).max(best_chunk_score);
|
||||
entity.scores.vector = Some(boosted);
|
||||
let fused = fuse_scores(&entity.scores, weights);
|
||||
entity.update_fused(fused);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut entity_results: Vec<Scored<KnowledgeEntity>> =
|
||||
entity_candidates.into_values().collect();
|
||||
sort_by_fused_desc(&mut entity_results);
|
||||
|
||||
let mut filtered_entities: Vec<Scored<KnowledgeEntity>> = entity_results
|
||||
.iter()
|
||||
.filter(|candidate| candidate.fused >= SCORE_THRESHOLD)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if filtered_entities.len() < FALLBACK_MIN_RESULTS {
|
||||
// Low recall scenarios still benefit from some context; take the top N regardless of score.
|
||||
filtered_entities = entity_results
|
||||
.into_iter()
|
||||
.take(FALLBACK_MIN_RESULTS)
|
||||
.collect();
|
||||
}
|
||||
|
||||
// 4) Re-rank chunks and prepare for attachment to surviving entities.
|
||||
let mut chunk_results: Vec<Scored<TextChunk>> = chunk_candidates.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_results);
|
||||
|
||||
let mut chunk_by_id: HashMap<String, Scored<TextChunk>> = HashMap::new();
|
||||
for chunk in chunk_results {
|
||||
chunk_by_id.insert(chunk.item.id.clone(), chunk);
|
||||
}
|
||||
|
||||
enrich_chunks_from_entities(
|
||||
&mut chunk_by_id,
|
||||
&filtered_entities,
|
||||
db_client,
|
||||
user_id,
|
||||
weights,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let source_ids = closest_chunks
|
||||
.iter()
|
||||
.map(|chunk: &TextChunk| chunk.source_id.clone())
|
||||
.collect::<Vec<String>>();
|
||||
let mut chunk_values: Vec<Scored<TextChunk>> = chunk_by_id.into_values().collect();
|
||||
sort_by_fused_desc(&mut chunk_values);
|
||||
|
||||
let items_from_text_chunk_similarity: Vec<KnowledgeEntity> =
|
||||
find_entities_by_source_ids(source_ids, "knowledge_entity".to_string(), db_client).await?;
|
||||
|
||||
let items_from_relationships_futures: Vec<_> = items_from_text_chunk_similarity
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|entity| find_entities_by_relationship_by_id(db_client, entity.id.clone()))
|
||||
.collect();
|
||||
|
||||
let items_from_relationships = try_join_all(items_from_relationships_futures)
|
||||
.await?
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<KnowledgeEntity>>();
|
||||
|
||||
let entities: Vec<KnowledgeEntity> = items_from_knowledge_entity_similarity
|
||||
.into_iter()
|
||||
.chain(items_from_text_chunk_similarity.into_iter())
|
||||
.chain(items_from_relationships.into_iter())
|
||||
.fold(HashMap::new(), |mut map, entity| {
|
||||
map.insert(entity.id.clone(), entity);
|
||||
map
|
||||
})
|
||||
.into_values()
|
||||
.collect();
|
||||
|
||||
Ok(entities)
|
||||
Ok(assemble_results(filtered_entities, chunk_values))
|
||||
}
|
||||
|
||||
// Minimal record used while seeding graph expansion so we can retain the original fused score.
|
||||
#[derive(Clone)]
|
||||
struct GraphSeed {
|
||||
id: String,
|
||||
fused: f32,
|
||||
}
|
||||
|
||||
async fn enrich_entities_from_graph(
|
||||
entity_candidates: &mut HashMap<String, Scored<KnowledgeEntity>>,
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
if entity_candidates.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Select a small frontier of high-confidence entities to seed the relationship walk.
|
||||
let mut seeds: Vec<GraphSeed> = entity_candidates
|
||||
.values()
|
||||
.filter(|entity| entity.fused >= GRAPH_SEED_MIN_SCORE)
|
||||
.map(|entity| GraphSeed {
|
||||
id: entity.item.id.clone(),
|
||||
fused: entity.fused,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if seeds.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Prioritise the strongest seeds so we explore the most grounded context first.
|
||||
seeds.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
seeds.truncate(GRAPH_TRAVERSAL_SEED_LIMIT);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
for seed in seeds.clone() {
|
||||
let user_id = user_id.to_owned();
|
||||
futures.push(async move {
|
||||
// Fetch neighbors concurrently to avoid serial graph round trips.
|
||||
let neighbors = find_entities_by_relationship_by_id(
|
||||
db_client,
|
||||
&seed.id,
|
||||
&user_id,
|
||||
GRAPH_NEIGHBOR_LIMIT,
|
||||
)
|
||||
.await;
|
||||
(seed, neighbors)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some((seed, neighbors_result)) = futures.next().await {
|
||||
let neighbors = neighbors_result.map_err(AppError::from)?;
|
||||
if neighbors.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fold neighbors back into the candidate map and let them inherit attenuated signal.
|
||||
for neighbor in neighbors {
|
||||
if neighbor.id == seed.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let graph_score = clamp_unit(seed.fused * GRAPH_SCORE_DECAY);
|
||||
let entry = entity_candidates
|
||||
.entry(neighbor.id.clone())
|
||||
.or_insert_with(|| Scored::new(neighbor.clone()));
|
||||
|
||||
entry.item = neighbor;
|
||||
|
||||
let inherited_vector = clamp_unit(graph_score * GRAPH_VECTOR_INHERITANCE);
|
||||
let vector_existing = entry.scores.vector.unwrap_or(0.0);
|
||||
if inherited_vector > vector_existing {
|
||||
entry.scores.vector = Some(inherited_vector);
|
||||
}
|
||||
|
||||
let existing_graph = entry.scores.graph.unwrap_or(f32::MIN);
|
||||
if graph_score > existing_graph {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
} else if entry.scores.graph.is_none() {
|
||||
entry.scores.graph = Some(graph_score);
|
||||
}
|
||||
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn normalize_fts_scores<T>(results: &mut [Scored<T>]) {
|
||||
// Scale BM25 outputs into [0,1] to keep fusion weights predictable.
|
||||
let raw_scores: Vec<f32> = results
|
||||
.iter()
|
||||
.map(|candidate| candidate.scores.fts.unwrap_or(0.0))
|
||||
.collect();
|
||||
|
||||
let normalized = min_max_normalize(&raw_scores);
|
||||
for (candidate, normalized_score) in results.iter_mut().zip(normalized.into_iter()) {
|
||||
candidate.scores.fts = Some(normalized_score);
|
||||
candidate.update_fused(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_fusion<T>(candidates: &mut HashMap<String, Scored<T>>, weights: FusionWeights)
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
// Collapse individual signals into a single fused score used for ranking.
|
||||
for candidate in candidates.values_mut() {
|
||||
let fused = fuse_scores(&candidate.scores, weights);
|
||||
candidate.update_fused(fused);
|
||||
}
|
||||
}
|
||||
|
||||
fn group_chunks_by_source(
|
||||
chunks: &HashMap<String, Scored<TextChunk>>,
|
||||
) -> HashMap<String, Vec<Scored<TextChunk>>> {
|
||||
// Preserve chunk candidates keyed by their originating source entity.
|
||||
let mut by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
|
||||
for chunk in chunks.values() {
|
||||
by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk.clone());
|
||||
}
|
||||
by_source
|
||||
}
|
||||
|
||||
async fn enrich_chunks_from_entities(
|
||||
chunk_candidates: &mut HashMap<String, Scored<TextChunk>>,
|
||||
entities: &[Scored<KnowledgeEntity>],
|
||||
db_client: &SurrealDbClient,
|
||||
user_id: &str,
|
||||
weights: FusionWeights,
|
||||
) -> Result<(), AppError> {
|
||||
// Fetch additional chunks referenced by entities that survived the fusion stage.
|
||||
let mut source_ids: HashSet<String> = HashSet::new();
|
||||
for entity in entities {
|
||||
source_ids.insert(entity.item.source_id.clone());
|
||||
}
|
||||
|
||||
if source_ids.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunks = find_entities_by_source_ids::<TextChunk>(
|
||||
source_ids.into_iter().collect(),
|
||||
"text_chunk",
|
||||
user_id,
|
||||
db_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut entity_score_lookup: HashMap<String, f32> = HashMap::new();
|
||||
// Cache fused scores per source so chunks inherit the strength of their parent entity.
|
||||
for entity in entities {
|
||||
entity_score_lookup.insert(entity.item.source_id.clone(), entity.fused);
|
||||
}
|
||||
|
||||
for chunk in chunks {
|
||||
// Ensure each chunk is represented so downstream selection sees the latest content.
|
||||
let entry = chunk_candidates
|
||||
.entry(chunk.id.clone())
|
||||
.or_insert_with(|| Scored::new(chunk.clone()).with_vector_score(0.0));
|
||||
|
||||
let entity_score = entity_score_lookup
|
||||
.get(&chunk.source_id)
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Lift chunk score toward the entity score so supporting evidence is prioritised.
|
||||
entry.scores.vector = Some(entry.scores.vector.unwrap_or(0.0).max(entity_score * 0.8));
|
||||
let fused = fuse_scores(&entry.scores, weights);
|
||||
entry.update_fused(fused);
|
||||
entry.item = chunk;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assemble_results(
|
||||
entities: Vec<Scored<KnowledgeEntity>>,
|
||||
mut chunks: Vec<Scored<TextChunk>>,
|
||||
) -> Vec<RetrievedEntity> {
|
||||
// Re-associate chunk candidates with their parent entity for ranked selection.
|
||||
let mut chunk_by_source: HashMap<String, Vec<Scored<TextChunk>>> = HashMap::new();
|
||||
for chunk in chunks.drain(..) {
|
||||
chunk_by_source
|
||||
.entry(chunk.item.source_id.clone())
|
||||
.or_default()
|
||||
.push(chunk);
|
||||
}
|
||||
|
||||
for chunk_list in chunk_by_source.values_mut() {
|
||||
sort_by_fused_desc(chunk_list);
|
||||
}
|
||||
|
||||
let mut token_budget_remaining = TOKEN_BUDGET_ESTIMATE;
|
||||
let mut results = Vec::new();
|
||||
|
||||
for entity in entities {
|
||||
// Attach best chunks first while respecting per-entity and global token caps.
|
||||
let mut selected_chunks = Vec::new();
|
||||
if let Some(candidates) = chunk_by_source.get_mut(&entity.item.source_id) {
|
||||
let mut per_entity_count = 0;
|
||||
candidates.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
for candidate in candidates.iter() {
|
||||
if per_entity_count >= MAX_CHUNKS_PER_ENTITY {
|
||||
break;
|
||||
}
|
||||
let estimated_tokens = estimate_tokens(&candidate.item.chunk);
|
||||
if estimated_tokens > token_budget_remaining {
|
||||
continue;
|
||||
}
|
||||
|
||||
token_budget_remaining = token_budget_remaining.saturating_sub(estimated_tokens);
|
||||
per_entity_count += 1;
|
||||
|
||||
selected_chunks.push(RetrievedChunk {
|
||||
chunk: candidate.item.clone(),
|
||||
score: candidate.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
results.push(RetrievedEntity {
|
||||
entity: entity.item.clone(),
|
||||
score: entity.fused,
|
||||
chunks: selected_chunks,
|
||||
});
|
||||
|
||||
if token_budget_remaining == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
fn estimate_tokens(text: &str) -> usize {
|
||||
// Simple heuristic to avoid calling a tokenizer in hot code paths.
|
||||
let chars = text.chars().count().max(1);
|
||||
(chars / AVG_CHARS_PER_TOKEN).max(1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use common::storage::types::{
        knowledge_entity::KnowledgeEntityType, knowledge_relationship::KnowledgeRelationship,
    };
    use uuid::Uuid;

    // Query vector used by the retrieval calls below; closest to the "high"
    // entity embedding and the "primary" chunk embedding.
    fn test_embedding() -> Vec<f32> {
        vec![0.9, 0.1, 0.0]
    }

    // Near the query vector — the entity carrying this should rank first.
    fn entity_embedding_high() -> Vec<f32> {
        vec![0.8, 0.2, 0.0]
    }

    // Far from the query vector — the entity carrying this should rank low.
    fn entity_embedding_low() -> Vec<f32> {
        vec![0.1, 0.9, 0.0]
    }

    // Chunk embedding close to the query vector.
    fn chunk_embedding_primary() -> Vec<f32> {
        vec![0.85, 0.15, 0.0]
    }

    // Chunk embedding far from the query vector.
    fn chunk_embedding_secondary() -> Vec<f32> {
        vec![0.2, 0.8, 0.0]
    }

    /// Boots an in-memory SurrealDB, applies the project migrations, then
    /// redefines the HNSW vector indexes with DIMENSION 3 so the tiny
    /// 3-component test embeddings above are accepted (the migrations
    /// presumably define production-sized dimensions — hence the override).
    async fn setup_test_db() -> SurrealDbClient {
        let namespace = "test_ns";
        // Fresh random database name per call keeps tests isolated.
        let database = &Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, database)
            .await
            .expect("Failed to start in-memory surrealdb");

        db.apply_migrations()
            .await
            .expect("Failed to apply migrations");

        db.query(
            "BEGIN TRANSACTION;
            REMOVE INDEX IF EXISTS idx_embedding_chunks ON TABLE text_chunk;
            DEFINE INDEX idx_embedding_chunks ON TABLE text_chunk FIELDS embedding HNSW DIMENSION 3;
            REMOVE INDEX IF EXISTS idx_embedding_entities ON TABLE knowledge_entity;
            DEFINE INDEX idx_embedding_entities ON TABLE knowledge_entity FIELDS embedding HNSW DIMENSION 3;
            COMMIT TRANSACTION;",
        )
        .await
        .expect("Failed to redefine vector indexes for tests");

        db
    }

    /// Seeds one relevant and one irrelevant entity for `user_id`, each with a
    /// text chunk sharing the entity's `source_id` so chunk retrieval can be
    /// attributed back to its entity.
    async fn seed_test_data(db: &SurrealDbClient, user_id: &str) {
        let entity_relevant = KnowledgeEntity::new(
            "source_a".into(),
            "Rust Concurrency Patterns".into(),
            "Discussion about async concurrency in Rust.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        let entity_irrelevant = KnowledgeEntity::new(
            "source_b".into(),
            "Python Tips".into(),
            "General Python programming tips.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_low(),
            user_id.into(),
        );

        db.store_item(entity_relevant.clone())
            .await
            .expect("Failed to store relevant entity");
        db.store_item(entity_irrelevant.clone())
            .await
            .expect("Failed to store irrelevant entity");

        let chunk_primary = TextChunk::new(
            entity_relevant.source_id.clone(),
            "Tokio enables async concurrency with lightweight tasks.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let chunk_secondary = TextChunk::new(
            entity_irrelevant.source_id.clone(),
            "Python focuses on readability and dynamic typing.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );

        db.store_item(chunk_primary)
            .await
            .expect("Failed to store primary chunk");
        db.store_item(chunk_secondary)
            .await
            .expect("Failed to store secondary chunk");
    }

    /// Hybrid retrieval should rank the vector-closest ("Rust") entity first
    /// and attach its supporting ("Tokio") chunk.
    #[tokio::test]
    async fn test_hybrid_retrieval_prioritises_relevant_entity() {
        let db = setup_test_db().await;
        let user_id = "user123";
        seed_test_data(&db, user_id).await;

        let results = retrieve_entities_with_embedding(
            &db,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
        )
        .await
        .expect("Hybrid retrieval failed");

        assert!(
            !results.is_empty(),
            "Expected at least one retrieval result"
        );

        let top = &results[0];

        assert!(
            top.entity.name.contains("Rust"),
            "Expected Rust entity to be ranked first"
        );
        assert!(
            !top.chunks.is_empty(),
            "Expected Rust entity to include supporting chunks"
        );

        let chunk_texts: Vec<&str> = top
            .chunks
            .iter()
            .map(|chunk| chunk.chunk.chunk.as_str())
            .collect();
        assert!(
            chunk_texts.iter().any(|text| text.contains("Tokio")),
            "Expected chunk discussing Tokio to be included"
        );
    }

    /// A neighbor entity that is vector-distant but linked by a graph
    /// relationship should still appear in results with a meaningful fused
    /// score, and should surface only its own chunks.
    #[tokio::test]
    async fn test_graph_relationship_enriches_results() {
        let db = setup_test_db().await;
        let user_id = "graph_user";

        let primary = KnowledgeEntity::new(
            "primary_source".into(),
            "Async Rust patterns".into(),
            "Explores async runtimes and scheduling strategies.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_high(),
            user_id.into(),
        );
        let neighbor = KnowledgeEntity::new(
            "neighbor_source".into(),
            "Tokio Scheduler Deep Dive".into(),
            "Details on Tokio's cooperative scheduler.".into(),
            KnowledgeEntityType::Document,
            None,
            entity_embedding_low(),
            user_id.into(),
        );

        db.store_item(primary.clone())
            .await
            .expect("Failed to store primary entity");
        db.store_item(neighbor.clone())
            .await
            .expect("Failed to store neighbor entity");

        let primary_chunk = TextChunk::new(
            primary.source_id.clone(),
            "Rust async tasks use Tokio's cooperative scheduler.".into(),
            chunk_embedding_primary(),
            user_id.into(),
        );
        let neighbor_chunk = TextChunk::new(
            neighbor.source_id.clone(),
            "Tokio's scheduler manages task fairness across executors.".into(),
            chunk_embedding_secondary(),
            user_id.into(),
        );

        db.store_item(primary_chunk)
            .await
            .expect("Failed to store primary chunk");
        db.store_item(neighbor_chunk)
            .await
            .expect("Failed to store neighbor chunk");

        // Link primary -> neighbor so graph expansion can pull the neighbor in.
        let relationship = KnowledgeRelationship::new(
            primary.id.clone(),
            neighbor.id.clone(),
            user_id.into(),
            "relationship_source".into(),
            "references".into(),
        );
        relationship
            .store_relationship(&db)
            .await
            .expect("Failed to store relationship");

        let results = retrieve_entities_with_embedding(
            &db,
            test_embedding(),
            "Rust concurrency async tasks",
            user_id,
        )
        .await
        .expect("Hybrid retrieval failed");

        let mut neighbor_entry = None;
        for entity in &results {
            if entity.entity.id == neighbor.id {
                neighbor_entry = Some(entity.clone());
            }
        }

        let neighbor_entry =
            neighbor_entry.expect("Graph-enriched neighbor should appear in results");

        assert!(
            neighbor_entry.score > 0.2,
            "Graph-enriched entity should have a meaningful fused score"
        );
        assert!(
            neighbor_entry
                .chunks
                .iter()
                .all(|chunk| chunk.chunk.source_id == neighbor.source_id),
            "Neighbor entity should surface its own supporting chunks"
        );
    }
}
|
||||
|
||||
180
composite-retrieval/src/scoring.rs
Normal file
180
composite-retrieval/src/scoring.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use common::storage::types::StoredObject;
|
||||
|
||||
/// Holds optional subscores gathered from different retrieval signals.
///
/// `None` means the signal never ranked the item — distinct from a score of
/// `0.0` (the distinction matters for the multi-signal bonus in
/// `fuse_scores`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Scores {
    /// Full-text-search subscore, when that signal matched.
    pub fts: Option<f32>,
    /// Vector-similarity subscore, when that signal matched.
    pub vector: Option<f32>,
    /// Graph-proximity subscore, when that signal matched.
    pub graph: Option<f32>,
}
|
||||
|
||||
/// Generic wrapper combining an item with its accumulated retrieval scores.
#[derive(Debug, Clone)]
pub struct Scored<T> {
    /// The retrieved item itself.
    pub item: T,
    /// Per-signal subscores collected so far.
    pub scores: Scores,
    /// Linearly fused overall score; `0.0` until `update_fused` is called.
    pub fused: f32,
}
|
||||
|
||||
impl<T> Scored<T> {
|
||||
pub fn new(item: T) -> Self {
|
||||
Self {
|
||||
item,
|
||||
scores: Scores::default(),
|
||||
fused: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_vector_score(mut self, score: f32) -> Self {
|
||||
self.scores.vector = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_fts_score(mut self, score: f32) -> Self {
|
||||
self.scores.fts = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_graph_score(mut self, score: f32) -> Self {
|
||||
self.scores.graph = Some(score);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn update_fused(&mut self, fused: f32) {
|
||||
self.fused = fused;
|
||||
}
|
||||
}
|
||||
|
||||
/// Weights used for linear score fusion.
#[derive(Debug, Clone, Copy)]
pub struct FusionWeights {
    /// Weight applied to the vector-similarity subscore.
    pub vector: f32,
    /// Weight applied to the full-text-search subscore.
    pub fts: f32,
    /// Weight applied to the graph subscore.
    pub graph: f32,
    /// Flat bonus added when two or more signals agree on an item.
    pub multi_bonus: f32,
}

impl Default for FusionWeights {
    /// Default mix: vector-heavy (0.5), then FTS (0.3), then graph (0.2),
    /// plus a small 0.02 corroboration bonus.
    fn default() -> Self {
        Self {
            vector: 0.5,
            fts: 0.3,
            graph: 0.2,
            multi_bonus: 0.02,
        }
    }
}
|
||||
|
||||
/// Clamps `value` into the closed unit interval `[0.0, 1.0]`.
///
/// Because `f32::max`/`f32::min` return the non-NaN operand, a NaN input
/// collapses to `0.0` rather than propagating.
pub fn clamp_unit(value: f32) -> f32 {
    let lower_bounded = value.max(0.0);
    lower_bounded.min(1.0)
}
|
||||
|
||||
pub fn distance_to_similarity(distance: f32) -> f32 {
|
||||
if !distance.is_finite() {
|
||||
return 0.0;
|
||||
}
|
||||
clamp_unit(1.0 / (1.0 + distance.max(0.0)))
|
||||
}
|
||||
|
||||
/// Rescales `scores` into `[0.0, 1.0]` with min-max normalization.
///
/// Non-finite entries (NaN/±inf) are ignored when computing the range and
/// always map to `0.0`. When every finite value is identical, the finite
/// entries all map to `1.0`. An empty slice yields an empty vector.
pub fn min_max_normalize(scores: &[f32]) -> Vec<f32> {
    if scores.is_empty() {
        return Vec::new();
    }

    // Track the finite range only, so NaN/inf cannot poison min/max.
    let mut min = f32::MAX;
    let mut max = f32::MIN;
    let mut saw_finite = false;

    for &score in scores {
        if score.is_finite() {
            saw_finite = true;
            min = min.min(score);
            max = max.max(score);
        }
    }

    // No finite value at all: nothing can be ranked.
    // (Replaces the old `!min.is_finite()` guard, which was dead code because
    // min/max start at the finite f32::MAX / f32::MIN sentinels.)
    if !saw_finite {
        return vec![0.0; scores.len()];
    }

    let span = max - min;
    if span.abs() < f32::EPSILON {
        // Degenerate range: all finite values are equal and map to 1.0.
        // Fix: non-finite entries previously also got 1.0 here; they now stay
        // 0.0, consistent with the general case below.
        return scores
            .iter()
            .map(|score| if score.is_finite() { 1.0 } else { 0.0 })
            .collect();
    }

    scores
        .iter()
        .map(|&score| {
            if score.is_finite() {
                // Clamp guards against rounding drift at the range edges.
                ((score - min) / span).max(0.0).min(1.0)
            } else {
                0.0
            }
        })
        .collect()
}
|
||||
|
||||
pub fn fuse_scores(scores: &Scores, weights: FusionWeights) -> f32 {
|
||||
let vector = scores.vector.unwrap_or(0.0);
|
||||
let fts = scores.fts.unwrap_or(0.0);
|
||||
let graph = scores.graph.unwrap_or(0.0);
|
||||
|
||||
let mut fused = vector * weights.vector + fts * weights.fts + graph * weights.graph;
|
||||
|
||||
let signals_present = scores
|
||||
.vector
|
||||
.iter()
|
||||
.chain(scores.fts.iter())
|
||||
.chain(scores.graph.iter())
|
||||
.count();
|
||||
if signals_present >= 2 {
|
||||
fused += weights.multi_bonus;
|
||||
}
|
||||
|
||||
clamp_unit(fused)
|
||||
}
|
||||
|
||||
pub fn merge_scored_by_id<T>(
|
||||
target: &mut std::collections::HashMap<String, Scored<T>>,
|
||||
incoming: Vec<Scored<T>>,
|
||||
) where
|
||||
T: StoredObject + Clone,
|
||||
{
|
||||
for scored in incoming {
|
||||
let id = scored.item.get_id().to_owned();
|
||||
target
|
||||
.entry(id)
|
||||
.and_modify(|existing| {
|
||||
if let Some(score) = scored.scores.vector {
|
||||
existing.scores.vector = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.fts {
|
||||
existing.scores.fts = Some(score);
|
||||
}
|
||||
if let Some(score) = scored.scores.graph {
|
||||
existing.scores.graph = Some(score);
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| Scored {
|
||||
item: scored.item.clone(),
|
||||
scores: scored.scores,
|
||||
fused: scored.fused,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sort_by_fused_desc<T>(items: &mut [Scored<T>])
|
||||
where
|
||||
T: StoredObject,
|
||||
{
|
||||
items.sort_by(|a, b| {
|
||||
b.fused
|
||||
.partial_cmp(&a.fused)
|
||||
.unwrap_or(Ordering::Equal)
|
||||
.then_with(|| a.item.get_id().cmp(b.item.get_id()))
|
||||
});
|
||||
}
|
||||
@@ -1,4 +1,15 @@
|
||||
use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::generate_embedding};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common::storage::types::file_info::deserialize_flexible_id;
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{db::SurrealDbClient, types::StoredObject},
|
||||
utils::embedding::generate_embedding,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use surrealdb::sql::Thing;
|
||||
|
||||
use crate::scoring::{clamp_unit, distance_to_similarity, Scored};
|
||||
|
||||
/// Compares vectors and retrieves a number of items from the specified table.
|
||||
///
|
||||
@@ -22,24 +33,125 @@ use common::{error::AppError, storage::db::SurrealDbClient, utils::embedding::ge
|
||||
///
|
||||
/// * `T` - The type to deserialize the query results into. Must implement `serde::Deserialize`.
|
||||
pub async fn find_items_by_vector_similarity<T>(
|
||||
take: u8,
|
||||
take: usize,
|
||||
input_text: &str,
|
||||
db_client: &SurrealDbClient,
|
||||
table: &str,
|
||||
openai_client: &async_openai::Client<async_openai::config::OpenAIConfig>,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<T>, AppError>
|
||||
) -> Result<Vec<Scored<T>>, AppError>
|
||||
where
|
||||
T: for<'de> serde::Deserialize<'de>,
|
||||
T: for<'de> serde::Deserialize<'de> + StoredObject,
|
||||
{
|
||||
// Generate embeddings
|
||||
let input_embedding = generate_embedding(openai_client, input_text, db_client).await?;
|
||||
|
||||
// Construct the query
|
||||
let closest_query = format!("SELECT *, vector::distance::knn() AS distance FROM {} WHERE user_id = '{}' AND embedding <|{},40|> {:?} ORDER BY distance", table, user_id, take, input_embedding);
|
||||
|
||||
// Perform query and deserialize to struct
|
||||
let closest_entities: Vec<T> = db_client.query(closest_query).await?.take(0)?;
|
||||
|
||||
Ok(closest_entities)
|
||||
find_items_by_vector_similarity_with_embedding(take, input_embedding, db_client, table, user_id)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Minimal projection of one KNN result row: the record id plus its distance.
#[derive(Debug, Deserialize)]
struct DistanceRow {
    // Ids are normalized to plain strings via the shared flexible-id helper
    // (presumably because SurrealDB may return either a Thing or a string —
    // confirm against `deserialize_flexible_id`).
    #[serde(deserialize_with = "deserialize_flexible_id")]
    id: String,
    // Distance may be absent if the engine did not report one for the row.
    distance: Option<f32>,
}
|
||||
|
||||
/// KNN lookup over `table` using a precomputed `query_embedding`, returning
/// up to `take` items wrapped in `Scored` with a normalized vector subscore.
///
/// Two round-trips: first an id+distance projection via the HNSW `<|k,ef|>`
/// operator, then a full fetch of the matching records (both scoped to
/// `user_id`). Distances are min-max normalized and inverted into
/// similarities when they span a meaningful range; otherwise each distance
/// falls back to `distance_to_similarity`.
pub async fn find_items_by_vector_similarity_with_embedding<T>(
    take: usize,
    query_embedding: Vec<f32>,
    db_client: &SurrealDbClient,
    table: &str,
    user_id: &str,
) -> Result<Vec<Scored<T>>, AppError>
where
    T: for<'de> serde::Deserialize<'de> + StoredObject,
{
    // The embedding is interpolated as a JSON array literal; `table` and
    // `take` are interpolated too (internal values, not user input).
    // NOTE(review): presumably the KNN operator cannot take bind parameters,
    // hence the literal — confirm against SurrealDB docs.
    let embedding_literal = serde_json::to_string(&query_embedding)
        .map_err(|err| AppError::InternalError(format!("Failed to serialize embedding: {err}")))?;
    let closest_query = format!(
        "SELECT id, vector::distance::knn() AS distance \
         FROM {table} \
         WHERE user_id = $user_id AND embedding <|{take},40|> {embedding} \
         LIMIT $limit",
        table = table,
        take = take,
        embedding = embedding_literal
    );

    let mut response = db_client
        .query(closest_query)
        .bind(("user_id", user_id.to_owned()))
        .bind(("limit", take as i64))
        .await?;

    // First statement's rows: record ids plus raw KNN distances, in rank order.
    let distance_rows: Vec<DistanceRow> = response.take(0)?;

    if distance_rows.is_empty() {
        return Ok(Vec::new());
    }

    // Second round-trip: fetch the full records by id, still user-scoped.
    let ids: Vec<String> = distance_rows.iter().map(|row| row.id.clone()).collect();
    let thing_ids: Vec<Thing> = ids
        .iter()
        .map(|id| Thing::from((table, id.as_str())))
        .collect();

    let mut items_response = db_client
        .query("SELECT * FROM type::table($table) WHERE id IN $things AND user_id = $user_id")
        .bind(("table", table.to_owned()))
        .bind(("things", thing_ids.clone()))
        .bind(("user_id", user_id.to_owned()))
        .await?;

    let items: Vec<T> = items_response.take(0)?;

    // Index fetched items by id so results can be emitted in KNN rank order.
    let mut item_map: HashMap<String, T> = items
        .into_iter()
        .map(|item| (item.get_id().to_owned(), item))
        .collect();

    // Finite-distance range across this result set, for min-max normalization.
    let mut min_distance = f32::MAX;
    let mut max_distance = f32::MIN;

    for row in &distance_rows {
        if let Some(distance) = row.distance {
            if distance.is_finite() {
                if distance < min_distance {
                    min_distance = distance;
                }
                if distance > max_distance {
                    max_distance = distance;
                }
            }
        }
    }

    // Only normalize when the finite distances span a meaningful range.
    let normalize = min_distance.is_finite()
        && max_distance.is_finite()
        && (max_distance - min_distance).abs() > f32::EPSILON;

    let mut scored = Vec::with_capacity(distance_rows.len());
    for row in distance_rows {
        // Rows whose record was not returned by the second fetch are dropped.
        if let Some(item) = item_map.remove(&row.id) {
            let similarity = row
                .distance
                .map(|distance| {
                    if normalize {
                        // Invert: the smallest distance maps to similarity 1.0.
                        let span = max_distance - min_distance;
                        if span.abs() < f32::EPSILON {
                            1.0
                        } else {
                            let normalized = 1.0 - ((distance - min_distance) / span);
                            clamp_unit(normalized)
                        }
                    } else {
                        // Degenerate range: fall back to 1 / (1 + d).
                        distance_to_similarity(distance)
                    }
                })
                .unwrap_or_default();
            scored.push(Scored::new(item).with_vector_score(similarity));
        }
    }

    Ok(scored)
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -4,6 +4,7 @@ use axum::{
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::try_join;
|
||||
use serde::Serialize;
|
||||
|
||||
@@ -139,6 +140,32 @@ pub struct ActiveJobsData {
|
||||
pub user: User,
|
||||
}
|
||||
|
||||
/// Flattened, template-ready view of one ingestion task for the archive modal.
#[derive(Serialize)]
struct TaskArchiveEntry {
    // Task record id, displayed for traceability.
    id: String,
    // Human-readable state (via `TaskState::display_label`).
    state_label: String,
    // Machine-readable state (via `TaskState::as_str`).
    state_raw: String,
    // Delivery attempts so far versus the configured ceiling.
    attempts: u32,
    max_attempts: u32,
    created_at: DateTime<Utc>,
    updated_at: DateTime<Utc>,
    // When the task is (next) due to run.
    scheduled_at: DateTime<Utc>,
    // Set while a worker holds the lease on this task.
    locked_at: Option<DateTime<Utc>>,
    last_error_at: Option<DateTime<Utc>>,
    error_message: Option<String>,
    // Worker the task is/was assigned to, if any.
    worker_id: Option<String>,
    priority: i32,
    lease_duration_secs: i64,
    // Short kind/summary pair derived from the ingestion payload
    // (see `summarize_task_content`).
    content_kind: String,
    content_summary: String,
}
|
||||
|
||||
/// Template context for the task archive modal: the viewer plus their tasks.
#[derive(Serialize)]
struct TaskArchiveData {
    // Authenticated user (the template reads `user.email` / `user.timezone`).
    user: User,
    tasks: Vec<TaskArchiveEntry>,
}
|
||||
|
||||
pub async fn delete_job(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
@@ -173,6 +200,70 @@ pub async fn show_active_jobs(
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn show_task_archive(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
) -> Result<impl IntoResponse, HtmlError> {
|
||||
let tasks = User::get_all_ingestion_tasks(&user.id, &state.db).await?;
|
||||
|
||||
let entries: Vec<TaskArchiveEntry> = tasks
|
||||
.into_iter()
|
||||
.map(|task| {
|
||||
let (content_kind, content_summary) = summarize_task_content(&task);
|
||||
|
||||
TaskArchiveEntry {
|
||||
id: task.id.clone(),
|
||||
state_label: task.state.display_label().to_string(),
|
||||
state_raw: task.state.as_str().to_string(),
|
||||
attempts: task.attempts,
|
||||
max_attempts: task.max_attempts,
|
||||
created_at: task.created_at,
|
||||
updated_at: task.updated_at,
|
||||
scheduled_at: task.scheduled_at,
|
||||
locked_at: task.locked_at,
|
||||
last_error_at: task.last_error_at,
|
||||
error_message: task.error_message.clone(),
|
||||
worker_id: task.worker_id.clone(),
|
||||
priority: task.priority,
|
||||
lease_duration_secs: task.lease_duration_secs,
|
||||
content_kind,
|
||||
content_summary,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(TemplateResponse::new_template(
|
||||
"dashboard/task_archive_modal.html",
|
||||
TaskArchiveData {
|
||||
user,
|
||||
tasks: entries,
|
||||
},
|
||||
))
|
||||
}
|
||||
|
||||
fn summarize_task_content(task: &IngestionTask) -> (String, String) {
|
||||
match &task.content {
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Text { text, .. } => {
|
||||
("Text".to_string(), truncate_summary(text, 80))
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::Url { url, .. } => {
|
||||
("URL".to_string(), url.to_string())
|
||||
}
|
||||
common::storage::types::ingestion_payload::IngestionPayload::File { file_info, .. } => {
|
||||
("File".to_string(), file_info.file_name.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates `input` to at most `max_chars` characters, appending `…` when
/// anything was cut off. Operates on `char` boundaries, so multi-byte text is
/// never split mid-character.
fn truncate_summary(input: &str, max_chars: usize) -> String {
    // Single pass: locate the byte offset of the (max_chars + 1)-th character.
    // The previous version scanned twice (`chars().count()` then
    // `chars().take()`); behavior is unchanged.
    match input.char_indices().nth(max_chars) {
        // Fewer than or exactly `max_chars` characters: return unchanged.
        None => input.to_string(),
        // Cut at the character boundary and mark the truncation.
        Some((cut, _)) => format!("{}…", &input[..cut]),
    }
}
|
||||
|
||||
pub async fn serve_file(
|
||||
State(state): State<HtmlState>,
|
||||
RequireUser(user): RequireUser,
|
||||
|
||||
@@ -5,7 +5,9 @@ use axum::{
|
||||
routing::{delete, get},
|
||||
Router,
|
||||
};
|
||||
use handlers::{delete_job, delete_text_content, index_handler, serve_file, show_active_jobs};
|
||||
use handlers::{
|
||||
delete_job, delete_text_content, index_handler, serve_file, show_active_jobs, show_task_archive,
|
||||
};
|
||||
|
||||
use crate::html_state::HtmlState;
|
||||
|
||||
@@ -24,6 +26,7 @@ where
|
||||
{
|
||||
Router::new()
|
||||
.route("/jobs/{job_id}", delete(delete_job))
|
||||
.route("/jobs/archive", get(show_task_archive))
|
||||
.route("/active-jobs", get(show_active_jobs))
|
||||
.route("/text-content/{id}", delete(delete_text_content))
|
||||
.route("/file/{id}", get(serve_file))
|
||||
|
||||
@@ -20,7 +20,7 @@ use common::{
|
||||
storage::types::{
|
||||
file_info::FileInfo,
|
||||
ingestion_payload::IngestionPayload,
|
||||
ingestion_task::{IngestionTask, IngestionTaskStatus},
|
||||
ingestion_task::{IngestionTask, TaskState},
|
||||
user::User,
|
||||
},
|
||||
};
|
||||
@@ -178,40 +178,54 @@ pub async fn get_task_updates_stream(
|
||||
Ok(Some(updated_task)) => {
|
||||
consecutive_db_errors = 0; // Reset error count on success
|
||||
|
||||
// Format the status message based on IngestionTaskStatus
|
||||
let status_message = match &updated_task.status {
|
||||
IngestionTaskStatus::Created => "Created".to_string(),
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
// Following your template's current display
|
||||
format!("In progress, attempt {}", attempts)
|
||||
let status_message = match updated_task.state {
|
||||
TaskState::Pending => "Pending".to_string(),
|
||||
TaskState::Reserved => format!(
|
||||
"Reserved (attempt {} of {})",
|
||||
updated_task.attempts,
|
||||
updated_task.max_attempts
|
||||
),
|
||||
TaskState::Processing => format!(
|
||||
"Processing (attempt {} of {})",
|
||||
updated_task.attempts,
|
||||
updated_task.max_attempts
|
||||
),
|
||||
TaskState::Succeeded => "Completed".to_string(),
|
||||
TaskState::Failed => {
|
||||
let mut base = format!(
|
||||
"Retry scheduled (attempt {} of {})",
|
||||
updated_task.attempts,
|
||||
updated_task.max_attempts
|
||||
);
|
||||
if let Some(message) = updated_task.error_message.as_ref() {
|
||||
base.push_str(": ");
|
||||
base.push_str(message);
|
||||
}
|
||||
base
|
||||
}
|
||||
IngestionTaskStatus::Completed => "Completed".to_string(),
|
||||
IngestionTaskStatus::Error { message } => {
|
||||
// Providing a user-friendly error message from the status
|
||||
format!("Error: {}", message)
|
||||
TaskState::Cancelled => "Cancelled".to_string(),
|
||||
TaskState::DeadLetter => {
|
||||
let mut base = "Failed permanently".to_string();
|
||||
if let Some(message) = updated_task.error_message.as_ref() {
|
||||
base.push_str(": ");
|
||||
base.push_str(message);
|
||||
}
|
||||
base
|
||||
}
|
||||
IngestionTaskStatus::Cancelled => "Cancelled".to_string(),
|
||||
};
|
||||
|
||||
yield Ok(Event::default().event("status").data(status_message));
|
||||
|
||||
// Check for terminal states to close the stream
|
||||
match updated_task.status {
|
||||
IngestionTaskStatus::Completed
|
||||
| IngestionTaskStatus::Error { .. }
|
||||
| IngestionTaskStatus::Cancelled => {
|
||||
// Send a specific event that HTMX uses to close the connection
|
||||
// Send a event to reload the recent content
|
||||
// Send a event to remove the loading indicatior
|
||||
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or("Ok".to_string());
|
||||
yield Ok(Event::default().event("stop_loading").data(check_icon));
|
||||
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
|
||||
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
||||
break; // Exit loop on terminal states
|
||||
}
|
||||
_ => {
|
||||
// Not a terminal state, continue polling
|
||||
}
|
||||
if updated_task.state.is_terminal() {
|
||||
// Send a specific event that HTMX uses to close the connection
|
||||
// Send a event to reload the recent content
|
||||
// Send a event to remove the loading indicatior
|
||||
let check_icon = state.templates.render("icons/check_icon.html", &context!{}).unwrap_or("Ok".to_string());
|
||||
yield Ok(Event::default().event("stop_loading").data(check_icon));
|
||||
yield Ok(Event::default().event("update_latest_content").data("Update latest content"));
|
||||
yield Ok(Event::default().event("close_stream").data("Stream complete"));
|
||||
break; // Exit loop on terminal states
|
||||
}
|
||||
},
|
||||
Ok(None) => {
|
||||
|
||||
@@ -2,10 +2,16 @@
|
||||
<section id="active_jobs_section" class="nb-panel p-4 space-y-4 mt-6 sm:mt-8">
|
||||
<header class="flex flex-wrap items-center justify-between gap-3">
|
||||
<h2 class="text-xl font-extrabold tracking-tight">Active Tasks</h2>
|
||||
<button class="nb-btn btn-square btn-sm" hx-get="/active-jobs" hx-target="#active_jobs_section" hx-swap="outerHTML"
|
||||
aria-label="Refresh active tasks">
|
||||
{% include "icons/refresh_icon.html" %}
|
||||
</button>
|
||||
<div class="flex gap-2">
|
||||
<button class="nb-btn btn-square btn-sm" hx-get="/active-jobs" hx-target="#active_jobs_section" hx-swap="outerHTML"
|
||||
aria-label="Refresh active tasks">
|
||||
{% include "icons/refresh_icon.html" %}
|
||||
</button>
|
||||
<button class="nb-btn btn-sm" hx-get="/jobs/archive" hx-target="#modal" hx-swap="innerHTML"
|
||||
aria-label="View task archive">
|
||||
View Archive
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
{% if active_jobs %}
|
||||
<ul class="flex flex-col gap-3 list-none p-0 m-0">
|
||||
@@ -23,12 +29,18 @@
|
||||
</div>
|
||||
<div class="space-y-1">
|
||||
<div class="text-sm font-semibold">
|
||||
{% if item.status.name == "InProgress" %}
|
||||
In progress, attempt {{ item.status.attempts }}
|
||||
{% elif item.status.name == "Error" %}
|
||||
Error: {{ item.status.message }}
|
||||
{% if item.state == "Processing" %}
|
||||
Processing, attempt {{ item.attempts }} of {{ item.max_attempts }}
|
||||
{% elif item.state == "Reserved" %}
|
||||
Reserved, attempt {{ item.attempts }} of {{ item.max_attempts }}
|
||||
{% elif item.state == "Failed" %}
|
||||
Retry scheduled (attempt {{ item.attempts }} of {{ item.max_attempts }}){% if item.error_message %}: {{ item.error_message }}{% endif %}
|
||||
{% elif item.state == "DeadLetter" %}
|
||||
Failed permanently{% if item.error_message %}: {{ item.error_message }}{% endif %}
|
||||
{% elif item.state == "Succeeded" %}
|
||||
Completed
|
||||
{% else %}
|
||||
{{ item.status.name }}
|
||||
{{ item.state }}
|
||||
{% endif %}
|
||||
</div>
|
||||
<div class="text-xs font-semibold opacity-60">
|
||||
@@ -60,4 +72,4 @@
|
||||
</ul>
|
||||
{% endif %}
|
||||
</section>
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
</div>
|
||||
<div class="space-y-1">
|
||||
<div class="text-sm font-semibold flex gap-2 items-center">
|
||||
<span sse-swap="status" hx-swap="innerHTML">Created</span>
|
||||
<span sse-swap="status" hx-swap="innerHTML">Pending</span>
|
||||
<div hx-get="/content/recent" hx-target="#latest_content_section" hx-swap="outerHTML"
|
||||
hx-trigger="sse:update_latest_content"></div>
|
||||
</div>
|
||||
|
||||
152
html-router/templates/dashboard/task_archive_modal.html
Normal file
152
html-router/templates/dashboard/task_archive_modal.html
Normal file
@@ -0,0 +1,152 @@
|
||||
{# Modal listing the full ingestion-task archive for the current user.
   Rendered by `show_task_archive` with a `TaskArchiveData` context
   (`user` + flattened `tasks`). Two layouts: a wide table on lg+ screens
   and a collapsible card list on smaller viewports. #}
{% extends "modal_base.html" %}

{% block modal_class %}w-11/12 max-w-[90ch] max-h-[95%] overflow-y-auto{% endblock %}

{# Read-only modal: swallow form submission so Enter does nothing. #}
{% block form_attributes %}onsubmit="event.preventDefault();"{% endblock %}

{% block modal_content %}
<h3 class="text-xl font-extrabold tracking-tight flex items-center gap-2">
  Ingestion Task Archive
  <span class="badge badge-neutral text-xs font-normal">{{ tasks|length }} total</span>
</h3>
<p class="text-sm opacity-70">A history of all ingestion tasks for {{ user.email }}.</p>

{% if tasks %}
{# Desktop layout: full-width table, hidden below the lg breakpoint. #}
<div class="hidden lg:block overflow-x-auto nb-card mt-4">
  <table class="nb-table">
    <thead>
      <tr>
        <th class="text-left">Content</th>
        <th class="text-left">State</th>
        <th class="text-left">Attempts</th>
        <th class="text-left">Scheduled</th>
        <th class="text-left">Updated</th>
        <th class="text-left">Worker</th>
        <th class="text-left">Error</th>
      </tr>
    </thead>
    <tbody>
      {% for task in tasks %}
      <tr>
        <td>
          <div class="flex flex-col gap-1">
            <div class="text-sm font-semibold">{{ task.content_kind }}</div>
            <div class="text-xs opacity-70 break-words">{{ task.content_summary }}</div>
            <div class="text-[11px] opacity-60 lowercase tracking-wider">{{ task.id }}</div>
          </div>
        </td>
        <td>
          <span class="badge badge-primary badge-outline tracking-wide">{{ task.state_label }}</span>
        </td>
        <td>
          <div class="text-sm font-semibold">{{ task.attempts }} / {{ task.max_attempts }}</div>
          <div class="text-xs opacity-60">Priority {{ task.priority }}</div>
        </td>
        <td>
          {# All timestamps are rendered in the viewer's timezone. #}
          <div class="text-sm">
            {{ task.scheduled_at|datetimeformat(format="short", tz=user.timezone) }}
          </div>
          {% if task.locked_at %}
          <div class="text-xs opacity-60">Locked {{ task.locked_at|datetimeformat(format="short", tz=user.timezone) }}
          </div>
          {% endif %}
        </td>
        <td>
          <div class="text-sm">
            {{ task.updated_at|datetimeformat(format="short", tz=user.timezone) }}
          </div>
          <div class="text-xs opacity-60">Created {{ task.created_at|datetimeformat(format="short", tz=user.timezone) }}
          </div>
        </td>
        <td>
          {% if task.worker_id %}
          <span class="text-sm font-semibold">{{ task.worker_id }}</span>
          <div class="text-xs opacity-60">Lease {{ task.lease_duration_secs }}s</div>
          {% else %}
          <span class="text-xs opacity-60">Not assigned</span>
          {% endif %}
        </td>
        <td>
          {% if task.error_message %}
          <div class="text-sm text-error font-semibold">{{ task.error_message }}</div>
          {% if task.last_error_at %}
          <div class="text-xs opacity-60">{{ task.last_error_at|datetimeformat(format="short", tz=user.timezone) }}
          </div>
          {% endif %}
          {% else %}
          <span class="text-xs opacity-60">—</span>
          {% endif %}
        </td>
      </tr>
      {% endfor %}
    </tbody>
  </table>
</div>
{# Mobile layout: one collapsible <details> card per task, lg screens hide it. #}
<div class="lg:hidden flex flex-col gap-3 mt-4">
  {% for task in tasks %}
  <details class="nb-panel p-3 space-y-3">
    <summary class="flex items-center justify-between gap-2 text-sm font-semibold cursor-pointer">
      <span>{{ task.content_kind }}</span>
      <span class="badge badge-primary badge-outline tracking-wide">{{ task.state_label }}</span>
    </summary>
    <div class="text-xs opacity-70 break-words">{{ task.content_summary }}</div>
    <div class="text-[11px] opacity-60 lowercase tracking-wider">{{ task.id }}</div>
    <div class="grid grid-cols-1 gap-2 text-xs">
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Attempts</span>
        <span class="text-sm font-semibold">{{ task.attempts }} / {{ task.max_attempts }}</span>
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Priority</span>
        <span class="text-sm font-semibold">{{ task.priority }}</span>
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Scheduled</span>
        <span>{{ task.scheduled_at|datetimeformat(format="short", tz=user.timezone) }}</span>
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Updated</span>
        <span>{{ task.updated_at|datetimeformat(format="short", tz=user.timezone) }}</span>
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Created</span>
        <span>{{ task.created_at|datetimeformat(format="short", tz=user.timezone) }}</span>
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Worker</span>
        {% if task.worker_id %}
        <span class="text-sm font-semibold">{{ task.worker_id }}</span>
        {% else %}
        <span class="opacity-60">Unassigned</span>
        {% endif %}
      </div>
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Lease</span>
        <span>{{ task.lease_duration_secs }}s</span>
      </div>
      {% if task.locked_at %}
      <div class="flex justify-between">
        <span class="opacity-60 uppercase tracking-wide">Locked</span>
        <span>{{ task.locked_at|datetimeformat(format="short", tz=user.timezone) }}</span>
      </div>
      {% endif %}
    </div>
    {% if task.error_message or task.last_error_at %}
    <div class="border-t border-base-200 pt-2 text-xs space-y-1">
      {% if task.error_message %}
      <div class="text-sm text-error font-semibold">{{ task.error_message }}</div>
      {% endif %}
      {% if task.last_error_at %}
      <div class="opacity-60">Last error {{ task.last_error_at|datetimeformat(format="short", tz=user.timezone) }}</div>
      {% endif %}
    </div>
    {% endif %}
  </details>
  {% endfor %}
</div>
{% else %}
<p class="text-sm opacity-70 mt-4">No tasks yet. Start an ingestion to populate the archive.</p>
{% endif %}
{% endblock %}

{# Read-only modal: no primary action buttons. #}
{% block primary_actions %}{% endblock %}
|
||||
@@ -7,13 +7,11 @@ use async_openai::types::{
|
||||
};
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{knowledge_entity::KnowledgeEntity, system_settings::SystemSettings},
|
||||
},
|
||||
storage::{db::SurrealDbClient, types::system_settings::SystemSettings},
|
||||
};
|
||||
use composite_retrieval::{
|
||||
answer_retrieval::format_entities_json, retrieve_entities, RetrievedEntity,
|
||||
};
|
||||
use composite_retrieval::retrieve_entities;
|
||||
use serde_json::json;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{
|
||||
@@ -61,7 +59,7 @@ impl IngestionEnricher {
|
||||
context: Option<&str>,
|
||||
text: &str,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<KnowledgeEntity>, AppError> {
|
||||
) -> Result<Vec<RetrievedEntity>, AppError> {
|
||||
let input_text = format!(
|
||||
"content: {}, category: {}, user_context: {:?}",
|
||||
text, category, context
|
||||
@@ -75,22 +73,11 @@ impl IngestionEnricher {
|
||||
category: &str,
|
||||
context: Option<&str>,
|
||||
text: &str,
|
||||
similar_entities: &[KnowledgeEntity],
|
||||
similar_entities: &[RetrievedEntity],
|
||||
) -> Result<CreateChatCompletionRequest, AppError> {
|
||||
let settings = SystemSettings::get_current(&self.db_client).await?;
|
||||
|
||||
let entities_json = json!(similar_entities
|
||||
.iter()
|
||||
.map(|entity| {
|
||||
json!({
|
||||
"KnowledgeEntity": {
|
||||
"id": entity.id,
|
||||
"name": entity.name,
|
||||
"description": entity.description
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>());
|
||||
let entities_json = format_entities_json(similar_entities);
|
||||
|
||||
let user_message = format!(
|
||||
"Category:\n{}\ncontext:\n{:?}\nContent:\n{}\nExisting KnowledgeEntities in database:\n{}",
|
||||
|
||||
@@ -3,101 +3,47 @@ pub mod pipeline;
|
||||
pub mod types;
|
||||
pub mod utils;
|
||||
|
||||
use chrono::Utc;
|
||||
use common::storage::{
|
||||
db::SurrealDbClient,
|
||||
types::ingestion_task::{IngestionTask, IngestionTaskStatus},
|
||||
types::ingestion_task::{IngestionTask, DEFAULT_LEASE_SECS},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use pipeline::IngestionPipeline;
|
||||
use std::sync::Arc;
|
||||
use surrealdb::Action;
|
||||
use tracing::{error, info};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub async fn run_worker_loop(
|
||||
db: Arc<SurrealDbClient>,
|
||||
ingestion_pipeline: Arc<IngestionPipeline>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let worker_id = format!("ingestion-worker-{}", Uuid::new_v4());
|
||||
let lease_duration = Duration::from_secs(DEFAULT_LEASE_SECS as u64);
|
||||
let idle_backoff = Duration::from_millis(500);
|
||||
|
||||
loop {
|
||||
// First, check for any unfinished tasks
|
||||
let unfinished_tasks = IngestionTask::get_unfinished_tasks(&db).await?;
|
||||
if !unfinished_tasks.is_empty() {
|
||||
info!("Found {} unfinished jobs", unfinished_tasks.len());
|
||||
for task in unfinished_tasks {
|
||||
ingestion_pipeline.process_task(task).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// If no unfinished jobs, start listening for new ones
|
||||
info!("Listening for new jobs...");
|
||||
let mut job_stream = IngestionTask::listen_for_tasks(&db).await?;
|
||||
while let Some(notification) = job_stream.next().await {
|
||||
match notification {
|
||||
Ok(notification) => {
|
||||
info!("Received notification: {:?}", notification);
|
||||
match notification.action {
|
||||
Action::Create => {
|
||||
if let Err(e) = ingestion_pipeline.process_task(notification.data).await
|
||||
{
|
||||
error!("Error processing task: {}", e);
|
||||
}
|
||||
}
|
||||
Action::Update => {
|
||||
match notification.data.status {
|
||||
IngestionTaskStatus::Completed
|
||||
| IngestionTaskStatus::Error { .. }
|
||||
| IngestionTaskStatus::Cancelled => {
|
||||
info!(
|
||||
"Skipping already completed/error/cancelled task: {}",
|
||||
notification.data.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => {
|
||||
// Only process if this is a retry after an error, not our own update
|
||||
if let Ok(Some(current_task)) =
|
||||
db.get_item::<IngestionTask>(¬ification.data.id).await
|
||||
{
|
||||
match current_task.status {
|
||||
IngestionTaskStatus::Error { .. }
|
||||
if attempts
|
||||
< common::storage::types::ingestion_task::MAX_ATTEMPTS =>
|
||||
{
|
||||
// This is a retry after an error
|
||||
if let Err(e) =
|
||||
ingestion_pipeline.process_task(current_task).await
|
||||
{
|
||||
error!("Error processing task retry: {}", e);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
info!(
|
||||
"Skipping in-progress update for task: {}",
|
||||
notification.data.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
IngestionTaskStatus::Created => {
|
||||
// Shouldn't happen with Update action, but process if it does
|
||||
if let Err(e) =
|
||||
ingestion_pipeline.process_task(notification.data).await
|
||||
{
|
||||
error!("Error processing task: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {} // Ignore other actions
|
||||
}
|
||||
match IngestionTask::claim_next_ready(&db, &worker_id, Utc::now(), lease_duration).await {
|
||||
Ok(Some(task)) => {
|
||||
let task_id = task.id.clone();
|
||||
info!(
|
||||
%worker_id,
|
||||
%task_id,
|
||||
attempt = task.attempts,
|
||||
"claimed ingestion task"
|
||||
);
|
||||
if let Err(err) = ingestion_pipeline.process_task(task).await {
|
||||
error!(%worker_id, %task_id, error = %err, "ingestion task failed");
|
||||
}
|
||||
Err(e) => error!("Error in job notification: {}", e),
|
||||
}
|
||||
Ok(None) => {
|
||||
sleep(idle_backoff).await;
|
||||
}
|
||||
Err(err) => {
|
||||
error!(%worker_id, error = %err, "failed to claim ingestion task");
|
||||
warn!("Backing off for 1s after claim error");
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// If we reach here, the stream has ended (connection lost?)
|
||||
error!("Database stream ended unexpectedly, reconnecting...");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::{sync::Arc, time::Instant};
|
||||
|
||||
use chrono::Utc;
|
||||
use text_splitter::TextSplitter;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{info, warn};
|
||||
use tracing::{info, info_span, warn};
|
||||
|
||||
use common::{
|
||||
error::AppError,
|
||||
storage::{
|
||||
db::SurrealDbClient,
|
||||
types::{
|
||||
ingestion_task::{IngestionTask, IngestionTaskStatus, MAX_ATTEMPTS},
|
||||
ingestion_task::{IngestionTask, TaskErrorInfo},
|
||||
knowledge_entity::KnowledgeEntity,
|
||||
knowledge_relationship::KnowledgeRelationship,
|
||||
text_chunk::TextChunk,
|
||||
@@ -44,47 +43,81 @@ impl IngestionPipeline {
|
||||
})
|
||||
}
|
||||
pub async fn process_task(&self, task: IngestionTask) -> Result<(), AppError> {
|
||||
let current_attempts = match task.status {
|
||||
IngestionTaskStatus::InProgress { attempts, .. } => attempts + 1,
|
||||
_ => 1,
|
||||
};
|
||||
let task_id = task.id.clone();
|
||||
let attempt = task.attempts;
|
||||
let worker_label = task
|
||||
.worker_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown-worker".to_string());
|
||||
let span = info_span!(
|
||||
"ingestion_task",
|
||||
%task_id,
|
||||
attempt,
|
||||
worker_id = %worker_label,
|
||||
state = %task.state.as_str()
|
||||
);
|
||||
let _enter = span.enter();
|
||||
let processing_task = task.mark_processing(&self.db).await?;
|
||||
|
||||
// Update status to InProgress with attempt count
|
||||
IngestionTask::update_status(
|
||||
&task.id,
|
||||
IngestionTaskStatus::InProgress {
|
||||
attempts: current_attempts,
|
||||
last_attempt: Utc::now(),
|
||||
},
|
||||
let text_content = to_text_content(
|
||||
processing_task.content.clone(),
|
||||
&self.db,
|
||||
&self.config,
|
||||
&self.openai_client,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let text_content =
|
||||
to_text_content(task.content, &self.db, &self.config, &self.openai_client).await?;
|
||||
|
||||
match self.process(&text_content).await {
|
||||
Ok(_) => {
|
||||
IngestionTask::update_status(&task.id, IngestionTaskStatus::Completed, &self.db)
|
||||
.await?;
|
||||
processing_task.mark_succeeded(&self.db).await?;
|
||||
info!(%task_id, attempt, "ingestion task succeeded");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
if current_attempts >= MAX_ATTEMPTS {
|
||||
IngestionTask::update_status(
|
||||
&task.id,
|
||||
IngestionTaskStatus::Error {
|
||||
message: format!("Max attempts reached: {}", e),
|
||||
},
|
||||
&self.db,
|
||||
)
|
||||
.await?;
|
||||
Err(err) => {
|
||||
let reason = err.to_string();
|
||||
let error_info = TaskErrorInfo {
|
||||
code: None,
|
||||
message: reason.clone(),
|
||||
};
|
||||
|
||||
if processing_task.can_retry() {
|
||||
let delay = Self::retry_delay(processing_task.attempts);
|
||||
processing_task
|
||||
.mark_failed(error_info, delay, &self.db)
|
||||
.await?;
|
||||
warn!(
|
||||
%task_id,
|
||||
attempt = processing_task.attempts,
|
||||
retry_in_secs = delay.as_secs(),
|
||||
"ingestion task failed; scheduled retry"
|
||||
);
|
||||
} else {
|
||||
processing_task
|
||||
.mark_dead_letter(error_info, &self.db)
|
||||
.await?;
|
||||
warn!(
|
||||
%task_id,
|
||||
attempt = processing_task.attempts,
|
||||
"ingestion task failed; moved to dead letter queue"
|
||||
);
|
||||
}
|
||||
Err(AppError::Processing(e.to_string()))
|
||||
|
||||
Err(AppError::Processing(reason))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn retry_delay(attempt: u32) -> Duration {
|
||||
const BASE_SECONDS: u64 = 30;
|
||||
const MAX_SECONDS: u64 = 15 * 60;
|
||||
|
||||
let capped_attempt = attempt.saturating_sub(1).min(5) as u32;
|
||||
let multiplier = 2_u64.pow(capped_attempt);
|
||||
let delay = BASE_SECONDS * multiplier;
|
||||
|
||||
Duration::from_secs(delay.min(MAX_SECONDS))
|
||||
}
|
||||
|
||||
pub async fn process(&self, content: &TextContent) -> Result<(), AppError> {
|
||||
let now = Instant::now();
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "main"
|
||||
version = "0.2.2"
|
||||
version = "0.2.3"
|
||||
edition = "2021"
|
||||
repository = "https://github.com/perstarkse/minne"
|
||||
license = "AGPL-3.0-or-later"
|
||||
|
||||
Reference in New Issue
Block a user