// Mirror of https://github.com/perstarkse/minne.git
// Synced 2026-03-18 23:44:18 +01:00 (1056 lines, 34 KiB, Rust)
use std::{
|
|
collections::{HashMap, HashSet},
|
|
fs,
|
|
io::BufReader,
|
|
path::PathBuf,
|
|
};
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
use chrono::{DateTime, Utc};
|
|
use common::storage::types::StoredObject;
|
|
use common::storage::{
|
|
db::SurrealDbClient,
|
|
types::{
|
|
knowledge_entity::KnowledgeEntity,
|
|
knowledge_entity_embedding::KnowledgeEntityEmbedding,
|
|
knowledge_relationship::{KnowledgeRelationship, RelationshipMetadata},
|
|
text_chunk::TextChunk,
|
|
text_chunk_embedding::TextChunkEmbedding,
|
|
text_content::TextContent,
|
|
},
|
|
};
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use surrealdb::sql::Thing;
|
|
use tracing::warn;
|
|
|
|
use crate::datasets::{ConvertedParagraph, ConvertedQuestion};
|
|
|
|
/// Current manifest format version; used as the serde default for `CorpusManifest::version`.
pub const MANIFEST_VERSION: u32 = 2;
/// Current shard format version; shards are stamped with it on load and persist.
pub const PARAGRAPH_SHARD_VERSION: u32 = 2;
/// Maximum number of items placed in a single insert batch.
const MANIFEST_BATCH_SIZE: usize = 100;
/// Default serialized-size cap for non-text batches.
const MANIFEST_MAX_BYTES_PER_BATCH: usize = 300_000;
/// Text bodies can be large, so their batches are limited more aggressively.
const TEXT_CONTENT_MAX_BYTES_PER_BATCH: usize = 250_000;
/// Maximum number of batches grouped into one query request.
const MAX_BATCHES_PER_REQUEST: usize = 24;
/// Total payload cap per Surreal query request.
const REQUEST_MAX_BYTES: usize = 800_000;
|
|
|
|
fn current_manifest_version() -> u32 {
|
|
MANIFEST_VERSION
|
|
}
|
|
|
|
fn current_paragraph_shard_version() -> u32 {
|
|
PARAGRAPH_SHARD_VERSION
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct EmbeddedKnowledgeEntity {
|
|
pub entity: KnowledgeEntity,
|
|
pub embedding: Vec<f32>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct EmbeddedTextChunk {
|
|
pub chunk: TextChunk,
|
|
pub embedding: Vec<f32>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize)]
|
|
struct LegacyKnowledgeEntity {
|
|
#[serde(flatten)]
|
|
pub entity: KnowledgeEntity,
|
|
#[serde(default)]
|
|
pub embedding: Vec<f32>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Deserialize)]
|
|
struct LegacyTextChunk {
|
|
#[serde(flatten)]
|
|
pub chunk: TextChunk,
|
|
#[serde(default)]
|
|
pub embedding: Vec<f32>,
|
|
}
|
|
|
|
fn deserialize_embedded_entities<'de, D>(
|
|
deserializer: D,
|
|
) -> Result<Vec<EmbeddedKnowledgeEntity>, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
#[derive(serde::Deserialize)]
|
|
#[serde(untagged)]
|
|
enum EntityInput {
|
|
Embedded(Vec<EmbeddedKnowledgeEntity>),
|
|
Legacy(Vec<LegacyKnowledgeEntity>),
|
|
}
|
|
|
|
match EntityInput::deserialize(deserializer)? {
|
|
EntityInput::Embedded(items) => Ok(items),
|
|
EntityInput::Legacy(items) => Ok(items
|
|
.into_iter()
|
|
.map(|legacy| EmbeddedKnowledgeEntity {
|
|
entity: legacy.entity,
|
|
embedding: legacy.embedding,
|
|
})
|
|
.collect()),
|
|
}
|
|
}
|
|
|
|
fn deserialize_embedded_chunks<'de, D>(deserializer: D) -> Result<Vec<EmbeddedTextChunk>, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
#[derive(serde::Deserialize)]
|
|
#[serde(untagged)]
|
|
enum ChunkInput {
|
|
Embedded(Vec<EmbeddedTextChunk>),
|
|
Legacy(Vec<LegacyTextChunk>),
|
|
}
|
|
|
|
match ChunkInput::deserialize(deserializer)? {
|
|
ChunkInput::Embedded(items) => Ok(items),
|
|
ChunkInput::Legacy(items) => Ok(items
|
|
.into_iter()
|
|
.map(|legacy| EmbeddedTextChunk {
|
|
chunk: legacy.chunk,
|
|
embedding: legacy.embedding,
|
|
})
|
|
.collect()),
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct CorpusManifest {
|
|
#[serde(default = "current_manifest_version")]
|
|
pub version: u32,
|
|
pub metadata: CorpusMetadata,
|
|
pub paragraphs: Vec<CorpusParagraph>,
|
|
pub questions: Vec<CorpusQuestion>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct CorpusMetadata {
|
|
pub dataset_id: String,
|
|
pub dataset_label: String,
|
|
pub slice_id: String,
|
|
pub include_unanswerable: bool,
|
|
#[serde(default)]
|
|
pub require_verified_chunks: bool,
|
|
pub ingestion_fingerprint: String,
|
|
pub embedding_backend: String,
|
|
pub embedding_model: Option<String>,
|
|
pub embedding_dimension: usize,
|
|
pub converted_checksum: String,
|
|
pub generated_at: DateTime<Utc>,
|
|
pub paragraph_count: usize,
|
|
pub question_count: usize,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct CorpusParagraph {
|
|
pub paragraph_id: String,
|
|
pub title: String,
|
|
pub text_content: TextContent,
|
|
#[serde(deserialize_with = "deserialize_embedded_entities")]
|
|
pub entities: Vec<EmbeddedKnowledgeEntity>,
|
|
pub relationships: Vec<KnowledgeRelationship>,
|
|
#[serde(deserialize_with = "deserialize_embedded_chunks")]
|
|
pub chunks: Vec<EmbeddedTextChunk>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct CorpusQuestion {
|
|
pub question_id: String,
|
|
pub paragraph_id: String,
|
|
pub text_content_id: String,
|
|
pub question_text: String,
|
|
pub answers: Vec<String>,
|
|
pub is_impossible: bool,
|
|
pub matching_chunk_ids: Vec<String>,
|
|
}
|
|
|
|
pub struct CorpusHandle {
|
|
pub manifest: CorpusManifest,
|
|
pub path: PathBuf,
|
|
pub reused_ingestion: bool,
|
|
pub reused_embeddings: bool,
|
|
pub positive_reused: usize,
|
|
pub positive_ingested: usize,
|
|
pub negative_reused: usize,
|
|
pub negative_ingested: usize,
|
|
}
|
|
|
|
pub fn window_manifest(
|
|
manifest: &CorpusManifest,
|
|
offset: usize,
|
|
length: usize,
|
|
negative_multiplier: f32,
|
|
) -> Result<CorpusManifest> {
|
|
let total = manifest.questions.len();
|
|
if total == 0 {
|
|
return Err(anyhow!(
|
|
"manifest contains no questions; cannot select a window"
|
|
));
|
|
}
|
|
if offset >= total {
|
|
return Err(anyhow!(
|
|
"window offset {} exceeds manifest questions ({})",
|
|
offset,
|
|
total
|
|
));
|
|
}
|
|
let end = (offset + length).min(total);
|
|
let questions = manifest.questions[offset..end].to_vec();
|
|
|
|
let selected_positive_ids: HashSet<_> =
|
|
questions.iter().map(|q| q.paragraph_id.clone()).collect();
|
|
let positives_all: HashSet<_> = manifest
|
|
.questions
|
|
.iter()
|
|
.map(|q| q.paragraph_id.as_str())
|
|
.collect();
|
|
let available_negatives = manifest
|
|
.paragraphs
|
|
.len()
|
|
.saturating_sub(positives_all.len());
|
|
let desired_negatives =
|
|
((selected_positive_ids.len() as f32) * negative_multiplier).ceil() as usize;
|
|
let desired_negatives = desired_negatives.min(available_negatives);
|
|
|
|
let mut paragraphs = Vec::new();
|
|
let mut negative_count = 0usize;
|
|
for paragraph in &manifest.paragraphs {
|
|
if selected_positive_ids.contains(¶graph.paragraph_id) {
|
|
paragraphs.push(paragraph.clone());
|
|
} else if negative_count < desired_negatives {
|
|
paragraphs.push(paragraph.clone());
|
|
negative_count += 1;
|
|
}
|
|
}
|
|
|
|
let mut narrowed = manifest.clone();
|
|
narrowed.questions = questions;
|
|
narrowed.paragraphs = paragraphs;
|
|
narrowed.metadata.paragraph_count = narrowed.paragraphs.len();
|
|
narrowed.metadata.question_count = narrowed.questions.len();
|
|
|
|
Ok(narrowed)
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
struct RelationInsert {
|
|
#[serde(rename = "in")]
|
|
pub in_: Thing,
|
|
#[serde(rename = "out")]
|
|
pub out: Thing,
|
|
pub id: String,
|
|
pub metadata: RelationshipMetadata,
|
|
}
|
|
|
|
/// A group of items together with the approximate size of their JSON encoding.
#[derive(Debug)]
struct SizedBatch<T> {
    /// Sum of the items' serialized lengths, never less than 1.
    approx_bytes: usize,
    /// The items that belong to this batch.
    items: Vec<T>,
}
|
|
|
|
struct ManifestBatches {
|
|
text_contents: Vec<SizedBatch<TextContent>>,
|
|
entities: Vec<SizedBatch<KnowledgeEntity>>,
|
|
entity_embeddings: Vec<SizedBatch<KnowledgeEntityEmbedding>>,
|
|
relationships: Vec<SizedBatch<RelationInsert>>,
|
|
chunks: Vec<SizedBatch<TextChunk>>,
|
|
chunk_embeddings: Vec<SizedBatch<TextChunkEmbedding>>,
|
|
}
|
|
|
|
fn build_manifest_batches(manifest: &CorpusManifest) -> Result<ManifestBatches> {
|
|
let mut text_contents = Vec::new();
|
|
let mut entities = Vec::new();
|
|
let mut entity_embeddings = Vec::new();
|
|
let mut relationships = Vec::new();
|
|
let mut chunks = Vec::new();
|
|
let mut chunk_embeddings = Vec::new();
|
|
|
|
let mut seen_text_content = HashSet::new();
|
|
let mut seen_entities = HashSet::new();
|
|
let mut seen_relationships = HashSet::new();
|
|
let mut seen_chunks = HashSet::new();
|
|
|
|
for paragraph in &manifest.paragraphs {
|
|
if seen_text_content.insert(paragraph.text_content.id.clone()) {
|
|
text_contents.push(paragraph.text_content.clone());
|
|
}
|
|
|
|
for embedded_entity in ¶graph.entities {
|
|
if seen_entities.insert(embedded_entity.entity.id.clone()) {
|
|
let entity = embedded_entity.entity.clone();
|
|
entities.push(entity.clone());
|
|
entity_embeddings.push(KnowledgeEntityEmbedding::new(
|
|
&entity.id,
|
|
embedded_entity.embedding.clone(),
|
|
entity.user_id.clone(),
|
|
));
|
|
}
|
|
}
|
|
|
|
for relationship in ¶graph.relationships {
|
|
if seen_relationships.insert(relationship.id.clone()) {
|
|
let table = KnowledgeEntity::table_name();
|
|
let in_id = relationship
|
|
.in_
|
|
.strip_prefix(&format!("{table}:"))
|
|
.unwrap_or(&relationship.in_);
|
|
let out_id = relationship
|
|
.out
|
|
.strip_prefix(&format!("{table}:"))
|
|
.unwrap_or(&relationship.out);
|
|
let in_thing = Thing::from((table, in_id));
|
|
let out_thing = Thing::from((table, out_id));
|
|
relationships.push(RelationInsert {
|
|
in_: in_thing,
|
|
out: out_thing,
|
|
id: relationship.id.clone(),
|
|
metadata: relationship.metadata.clone(),
|
|
});
|
|
}
|
|
}
|
|
|
|
for embedded_chunk in ¶graph.chunks {
|
|
if seen_chunks.insert(embedded_chunk.chunk.id.clone()) {
|
|
let chunk = embedded_chunk.chunk.clone();
|
|
chunks.push(chunk.clone());
|
|
chunk_embeddings.push(TextChunkEmbedding::new(
|
|
&chunk.id,
|
|
chunk.source_id.clone(),
|
|
embedded_chunk.embedding.clone(),
|
|
chunk.user_id.clone(),
|
|
));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(ManifestBatches {
|
|
text_contents: chunk_items(
|
|
&text_contents,
|
|
MANIFEST_BATCH_SIZE,
|
|
TEXT_CONTENT_MAX_BYTES_PER_BATCH,
|
|
)
|
|
.context("chunking text_content payloads")?,
|
|
entities: chunk_items(&entities, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
|
.context("chunking knowledge_entity payloads")?,
|
|
entity_embeddings: chunk_items(
|
|
&entity_embeddings,
|
|
MANIFEST_BATCH_SIZE,
|
|
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
)
|
|
.context("chunking knowledge_entity_embedding payloads")?,
|
|
relationships: chunk_items(
|
|
&relationships,
|
|
MANIFEST_BATCH_SIZE,
|
|
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
)
|
|
.context("chunking relationship payloads")?,
|
|
chunks: chunk_items(&chunks, MANIFEST_BATCH_SIZE, MANIFEST_MAX_BYTES_PER_BATCH)
|
|
.context("chunking text_chunk payloads")?,
|
|
chunk_embeddings: chunk_items(
|
|
&chunk_embeddings,
|
|
MANIFEST_BATCH_SIZE,
|
|
MANIFEST_MAX_BYTES_PER_BATCH,
|
|
)
|
|
.context("chunking text_chunk_embedding payloads")?,
|
|
})
|
|
}
|
|
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct ParagraphShard {
|
|
#[serde(default = "current_paragraph_shard_version")]
|
|
pub version: u32,
|
|
pub paragraph_id: String,
|
|
pub shard_path: String,
|
|
pub ingestion_fingerprint: String,
|
|
pub ingested_at: DateTime<Utc>,
|
|
pub title: String,
|
|
pub text_content: TextContent,
|
|
#[serde(deserialize_with = "deserialize_embedded_entities")]
|
|
pub entities: Vec<EmbeddedKnowledgeEntity>,
|
|
pub relationships: Vec<KnowledgeRelationship>,
|
|
#[serde(deserialize_with = "deserialize_embedded_chunks")]
|
|
pub chunks: Vec<EmbeddedTextChunk>,
|
|
#[serde(default)]
|
|
pub question_bindings: HashMap<String, Vec<String>>,
|
|
#[serde(default)]
|
|
pub embedding_backend: String,
|
|
#[serde(default)]
|
|
pub embedding_model: Option<String>,
|
|
#[serde(default)]
|
|
pub embedding_dimension: usize,
|
|
}
|
|
|
|
/// Filesystem-backed store for [`ParagraphShard`] files.
pub struct ParagraphShardStore {
    /// Directory that all relative shard paths are resolved against.
    base_dir: PathBuf,
}
|
|
|
|
impl ParagraphShardStore {
|
|
pub fn new(base_dir: PathBuf) -> Self {
|
|
Self { base_dir }
|
|
}
|
|
|
|
pub fn ensure_base_dir(&self) -> Result<()> {
|
|
fs::create_dir_all(&self.base_dir)
|
|
.with_context(|| format!("creating shard base dir {}", self.base_dir.display()))
|
|
}
|
|
|
|
fn resolve(&self, relative: &str) -> PathBuf {
|
|
self.base_dir.join(relative)
|
|
}
|
|
|
|
pub fn load(&self, relative: &str, fingerprint: &str) -> Result<Option<ParagraphShard>> {
|
|
let path = self.resolve(relative);
|
|
let file = match fs::File::open(&path) {
|
|
Ok(file) => file,
|
|
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
|
|
Err(err) => {
|
|
return Err(err).with_context(|| format!("opening shard {}", path.display()))
|
|
}
|
|
};
|
|
let reader = BufReader::new(file);
|
|
let mut shard: ParagraphShard = serde_json::from_reader(reader)
|
|
.with_context(|| format!("parsing shard {}", path.display()))?;
|
|
|
|
if shard.ingestion_fingerprint != fingerprint {
|
|
return Ok(None);
|
|
}
|
|
if shard.version != PARAGRAPH_SHARD_VERSION {
|
|
warn!(
|
|
path = %path.display(),
|
|
version = shard.version,
|
|
expected = PARAGRAPH_SHARD_VERSION,
|
|
"Upgrading shard to current version"
|
|
);
|
|
shard.version = PARAGRAPH_SHARD_VERSION;
|
|
}
|
|
shard.shard_path = relative.to_string();
|
|
Ok(Some(shard))
|
|
}
|
|
|
|
pub fn persist(&self, shard: &ParagraphShard) -> Result<()> {
|
|
let mut shard = shard.clone();
|
|
shard.version = PARAGRAPH_SHARD_VERSION;
|
|
|
|
let path = self.resolve(&shard.shard_path);
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)
|
|
.with_context(|| format!("creating shard dir {}", parent.display()))?;
|
|
}
|
|
let tmp_path = path.with_extension("json.tmp");
|
|
let body = serde_json::to_vec_pretty(&shard).context("serialising paragraph shard")?;
|
|
fs::write(&tmp_path, &body)
|
|
.with_context(|| format!("writing shard tmp {}", tmp_path.display()))?;
|
|
fs::rename(&tmp_path, &path)
|
|
.with_context(|| format!("renaming shard tmp {}", path.display()))?;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl ParagraphShard {
|
|
pub fn new(
|
|
paragraph: &ConvertedParagraph,
|
|
shard_path: String,
|
|
ingestion_fingerprint: &str,
|
|
text_content: TextContent,
|
|
entities: Vec<EmbeddedKnowledgeEntity>,
|
|
relationships: Vec<KnowledgeRelationship>,
|
|
chunks: Vec<EmbeddedTextChunk>,
|
|
embedding_backend: &str,
|
|
embedding_model: Option<String>,
|
|
embedding_dimension: usize,
|
|
) -> Self {
|
|
Self {
|
|
version: PARAGRAPH_SHARD_VERSION,
|
|
paragraph_id: paragraph.id.clone(),
|
|
shard_path,
|
|
ingestion_fingerprint: ingestion_fingerprint.to_string(),
|
|
ingested_at: Utc::now(),
|
|
title: paragraph.title.clone(),
|
|
text_content,
|
|
entities,
|
|
relationships,
|
|
chunks,
|
|
question_bindings: HashMap::new(),
|
|
embedding_backend: embedding_backend.to_string(),
|
|
embedding_model,
|
|
embedding_dimension,
|
|
}
|
|
}
|
|
|
|
pub fn to_corpus_paragraph(&self) -> CorpusParagraph {
|
|
CorpusParagraph {
|
|
paragraph_id: self.paragraph_id.clone(),
|
|
title: self.title.clone(),
|
|
text_content: self.text_content.clone(),
|
|
entities: self.entities.clone(),
|
|
relationships: self.relationships.clone(),
|
|
chunks: self.chunks.clone(),
|
|
}
|
|
}
|
|
|
|
pub fn ensure_question_binding(
|
|
&mut self,
|
|
question: &ConvertedQuestion,
|
|
) -> Result<(Vec<String>, bool)> {
|
|
if let Some(existing) = self.question_bindings.get(&question.id) {
|
|
return Ok((existing.clone(), false));
|
|
}
|
|
let chunk_ids = validate_answers(&self.text_content, &self.chunks, question)?;
|
|
self.question_bindings
|
|
.insert(question.id.clone(), chunk_ids.clone());
|
|
Ok((chunk_ids, true))
|
|
}
|
|
}
|
|
|
|
fn validate_answers(
|
|
content: &TextContent,
|
|
chunks: &[EmbeddedTextChunk],
|
|
question: &ConvertedQuestion,
|
|
) -> Result<Vec<String>> {
|
|
if question.is_impossible || question.answers.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let mut matches = std::collections::BTreeSet::new();
|
|
let mut found_any = false;
|
|
let haystack = content.text.to_ascii_lowercase();
|
|
let haystack_norm = normalize_answer_text(&haystack);
|
|
for answer in &question.answers {
|
|
let needle: String = answer.to_ascii_lowercase();
|
|
let needle_norm = normalize_answer_text(&needle);
|
|
let text_match = haystack.contains(&needle)
|
|
|| (!needle_norm.is_empty() && haystack_norm.contains(&needle_norm));
|
|
if text_match {
|
|
found_any = true;
|
|
}
|
|
for chunk in chunks {
|
|
let chunk_text = chunk.chunk.chunk.to_ascii_lowercase();
|
|
let chunk_norm = normalize_answer_text(&chunk_text);
|
|
if chunk_text.contains(&needle)
|
|
|| (!needle_norm.is_empty() && chunk_norm.contains(&needle_norm))
|
|
{
|
|
matches.insert(chunk.chunk.get_id().to_string());
|
|
found_any = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
if !found_any {
|
|
Err(anyhow!(
|
|
"expected answer for question '{}' was not found in ingested content",
|
|
question.id
|
|
))
|
|
} else {
|
|
Ok(matches.into_iter().collect())
|
|
}
|
|
}
|
|
|
|
/// Normalizes text for loose answer matching.
///
/// Every character that is not alphanumeric (punctuation, symbols and
/// whitespace alike) acts as a separator. The remaining runs of alphanumeric
/// characters are ASCII-lowercased and joined with single spaces, with no
/// leading or trailing space. Non-ASCII letters are kept unchanged.
fn normalize_answer_text(text: &str) -> String {
    // Single pass: no intermediate `String` or `Vec<&str>` is built.
    let mut normalized = String::with_capacity(text.len());
    let mut pending_separator = false;

    for ch in text.chars() {
        if ch.is_alphanumeric() {
            // Emit one space between words, never before the first word.
            if pending_separator && !normalized.is_empty() {
                normalized.push(' ');
            }
            pending_separator = false;
            normalized.push(ch.to_ascii_lowercase());
        } else {
            pending_separator = true;
        }
    }

    normalized
}
|
|
|
|
fn chunk_items<T: Clone + Serialize>(
|
|
items: &[T],
|
|
max_items: usize,
|
|
max_bytes: usize,
|
|
) -> Result<Vec<SizedBatch<T>>> {
|
|
if items.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let mut batches = Vec::new();
|
|
let mut current = Vec::new();
|
|
let mut current_bytes = 0usize;
|
|
|
|
for item in items {
|
|
let size = serde_json::to_vec(item)
|
|
.map(|buf| buf.len())
|
|
.context("serialising batch item for sizing")?;
|
|
|
|
let would_overflow_items = !current.is_empty() && current.len() >= max_items;
|
|
let would_overflow_bytes = !current.is_empty() && current_bytes + size > max_bytes;
|
|
|
|
if would_overflow_items || would_overflow_bytes {
|
|
batches.push(SizedBatch {
|
|
approx_bytes: current_bytes.max(1),
|
|
items: std::mem::take(&mut current),
|
|
});
|
|
current_bytes = 0;
|
|
}
|
|
|
|
current_bytes += size;
|
|
current.push(item.clone());
|
|
}
|
|
|
|
if !current.is_empty() {
|
|
batches.push(SizedBatch {
|
|
approx_bytes: current_bytes.max(1),
|
|
items: current,
|
|
});
|
|
}
|
|
|
|
Ok(batches)
|
|
}
|
|
|
|
async fn execute_batched_inserts<T: Clone + Serialize + 'static>(
|
|
db: &SurrealDbClient,
|
|
statement: impl AsRef<str>,
|
|
prefix: &str,
|
|
batches: &[SizedBatch<T>],
|
|
) -> Result<()> {
|
|
if batches.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let mut start = 0;
|
|
while start < batches.len() {
|
|
let mut group_bytes = 0usize;
|
|
let mut group_end = start;
|
|
let mut group_count = 0usize;
|
|
|
|
while group_end < batches.len() {
|
|
let batch_bytes = batches[group_end].approx_bytes.max(1);
|
|
if group_count > 0
|
|
&& (group_bytes + batch_bytes > REQUEST_MAX_BYTES
|
|
|| group_count >= MAX_BATCHES_PER_REQUEST)
|
|
{
|
|
break;
|
|
}
|
|
group_bytes += batch_bytes;
|
|
group_end += 1;
|
|
group_count += 1;
|
|
}
|
|
|
|
let slice = &batches[start..group_end];
|
|
let mut query = db.client.query("BEGIN TRANSACTION;");
|
|
let mut bind_index = 0usize;
|
|
for batch in slice {
|
|
let name = format!("{prefix}{bind_index}");
|
|
bind_index += 1;
|
|
query = query
|
|
.query(format!("{} ${};", statement.as_ref(), name))
|
|
.bind((name, batch.items.clone()));
|
|
}
|
|
let response = query
|
|
.query("COMMIT TRANSACTION;")
|
|
.await
|
|
.context("executing batched insert transaction")?;
|
|
if let Err(err) = response.check() {
|
|
return Err(anyhow!(
|
|
"batched insert failed for statement '{}': {err:?}",
|
|
statement.as_ref()
|
|
));
|
|
}
|
|
|
|
start = group_end;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn seed_manifest_into_db(db: &SurrealDbClient, manifest: &CorpusManifest) -> Result<()> {
|
|
let batches = build_manifest_batches(manifest).context("preparing manifest batches")?;
|
|
|
|
let result = (|| async {
|
|
execute_batched_inserts(
|
|
db,
|
|
format!("INSERT INTO {}", TextContent::table_name()),
|
|
"tc",
|
|
&batches.text_contents,
|
|
)
|
|
.await?;
|
|
|
|
execute_batched_inserts(
|
|
db,
|
|
format!("INSERT INTO {}", KnowledgeEntity::table_name()),
|
|
"ke",
|
|
&batches.entities,
|
|
)
|
|
.await?;
|
|
|
|
execute_batched_inserts(
|
|
db,
|
|
format!("INSERT INTO {}", TextChunk::table_name()),
|
|
"ch",
|
|
&batches.chunks,
|
|
)
|
|
.await?;
|
|
|
|
execute_batched_inserts(
|
|
db,
|
|
"INSERT RELATION INTO relates_to",
|
|
"rel",
|
|
&batches.relationships,
|
|
)
|
|
.await?;
|
|
|
|
execute_batched_inserts(
|
|
db,
|
|
format!("INSERT INTO {}", KnowledgeEntityEmbedding::table_name()),
|
|
"kee",
|
|
&batches.entity_embeddings,
|
|
)
|
|
.await?;
|
|
|
|
execute_batched_inserts(
|
|
db,
|
|
format!("INSERT INTO {}", TextChunkEmbedding::table_name()),
|
|
"tce",
|
|
&batches.chunk_embeddings,
|
|
)
|
|
.await?;
|
|
|
|
Ok(())
|
|
})()
|
|
.await;
|
|
|
|
if result.is_err() {
|
|
// Best-effort cleanup to avoid leaving partial manifest data behind.
|
|
let _ = db
|
|
.client
|
|
.query(
|
|
"BEGIN TRANSACTION;
|
|
DELETE text_chunk_embedding;
|
|
DELETE knowledge_entity_embedding;
|
|
DELETE relates_to;
|
|
DELETE text_chunk;
|
|
DELETE knowledge_entity;
|
|
DELETE text_content;
|
|
COMMIT TRANSACTION;",
|
|
)
|
|
.await;
|
|
}
|
|
|
|
result
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db_helpers::change_embedding_length_in_hnsw_indexes;
    use chrono::Utc;
    use common::storage::types::knowledge_entity::KnowledgeEntityType;
    use uuid::Uuid;

    /// Builds a two-paragraph manifest in which both paragraphs share the same
    /// text content, entity and chunk, plus one question on the first paragraph.
    fn build_manifest() -> CorpusManifest {
        let user_id = "user-1".to_string();
        let source_id = "source-1".to_string();
        let now = Utc::now();
        let text_content_id = Uuid::new_v4().to_string();

        let text_content = TextContent {
            id: text_content_id.clone(),
            created_at: now,
            updated_at: now,
            text: "Hello world".to_string(),
            file_info: None,
            url_info: None,
            context: None,
            category: "test".to_string(),
            user_id: user_id.clone(),
        };

        let entity = KnowledgeEntity {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id: source_id.clone(),
            name: "Entity".to_string(),
            description: "A test entity".to_string(),
            entity_type: KnowledgeEntityType::Document,
            metadata: None,
            user_id: user_id.clone(),
        };
        let relationship = KnowledgeRelationship::new(
            format!("knowledge_entity:{}", entity.id),
            format!("knowledge_entity:{}", entity.id),
            user_id.clone(),
            source_id.clone(),
            "related".to_string(),
        );

        let chunk = TextChunk {
            id: Uuid::new_v4().to_string(),
            created_at: now,
            updated_at: now,
            source_id: source_id.clone(),
            chunk: "chunk text".to_string(),
            user_id: user_id.clone(),
        };

        let paragraph_one = CorpusParagraph {
            paragraph_id: "p1".to_string(),
            title: "Paragraph 1".to_string(),
            text_content: text_content.clone(),
            entities: vec![EmbeddedKnowledgeEntity {
                entity: entity.clone(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            relationships: vec![relationship],
            chunks: vec![EmbeddedTextChunk {
                chunk: chunk.clone(),
                embedding: vec![0.3, 0.2, 0.1],
            }],
        };

        // Duplicate content/entities should be de-duplicated by the loader.
        let paragraph_two = CorpusParagraph {
            paragraph_id: "p2".to_string(),
            title: "Paragraph 2".to_string(),
            text_content: text_content.clone(),
            entities: vec![EmbeddedKnowledgeEntity {
                entity: entity.clone(),
                embedding: vec![0.1, 0.2, 0.3],
            }],
            relationships: Vec::new(),
            chunks: vec![EmbeddedTextChunk {
                chunk: chunk.clone(),
                embedding: vec![0.3, 0.2, 0.1],
            }],
        };

        let question = CorpusQuestion {
            question_id: "q1".to_string(),
            paragraph_id: paragraph_one.paragraph_id.clone(),
            text_content_id,
            question_text: "What is this?".to_string(),
            answers: vec!["Hello".to_string()],
            is_impossible: false,
            matching_chunk_ids: vec![chunk.id.clone()],
        };

        CorpusManifest {
            version: current_manifest_version(),
            metadata: CorpusMetadata {
                dataset_id: "dataset".to_string(),
                dataset_label: "Dataset".to_string(),
                slice_id: "slice".to_string(),
                include_unanswerable: false,
                require_verified_chunks: false,
                ingestion_fingerprint: "fp".to_string(),
                embedding_backend: "test".to_string(),
                embedding_model: Some("model".to_string()),
                embedding_dimension: 3,
                converted_checksum: "checksum".to_string(),
                generated_at: now,
                paragraph_count: 2,
                question_count: 1,
            },
            paragraphs: vec![paragraph_one, paragraph_two],
            questions: vec![question],
        }
    }

    /// Opens a fresh in-memory database in `namespace` and applies the migrations.
    async fn memory_db(namespace: &str) -> SurrealDbClient {
        let database = Uuid::new_v4().to_string();
        let db = SurrealDbClient::memory(namespace, &database)
            .await
            .expect("memory db");
        db.apply_migrations()
            .await
            .expect("apply migrations for memory db");
        db
    }

    /// Selects every row of `table`.
    ///
    /// Panics on query or deserialization errors so that an emptiness
    /// assertion cannot pass merely because the rows failed to decode.
    async fn select_all<T>(db: &SurrealDbClient, table: &str) -> Vec<T>
    where
        T: serde::de::DeserializeOwned,
    {
        db.client
            .query(format!("SELECT * FROM {table};"))
            .await
            .unwrap_or_else(|err| panic!("select {table}: {err:?}"))
            .take(0)
            .unwrap_or_else(|err| panic!("decode rows of {table}: {err:?}"))
    }

    #[tokio::test]
    async fn seeds_manifest_with_transactional_batches() {
        let db = memory_db("test_ns").await;
        change_embedding_length_in_hnsw_indexes(&db, 3)
            .await
            .expect("set embedding index dimension for test");

        let manifest = build_manifest();
        seed_manifest_into_db(&db, &manifest)
            .await
            .expect("manifest seed should succeed");

        let text_contents: Vec<TextContent> = select_all(&db, TextContent::table_name()).await;
        assert_eq!(text_contents.len(), 1);

        let entities: Vec<KnowledgeEntity> = select_all(&db, KnowledgeEntity::table_name()).await;
        assert_eq!(entities.len(), 1);

        let chunks: Vec<TextChunk> = select_all(&db, TextChunk::table_name()).await;
        assert_eq!(chunks.len(), 1);

        let relationships: Vec<KnowledgeRelationship> = select_all(&db, "relates_to").await;
        assert_eq!(relationships.len(), 1);

        let entity_embeddings: Vec<KnowledgeEntityEmbedding> =
            select_all(&db, KnowledgeEntityEmbedding::table_name()).await;
        assert_eq!(entity_embeddings.len(), 1);

        let chunk_embeddings: Vec<TextChunkEmbedding> =
            select_all(&db, TextChunkEmbedding::table_name()).await;
        assert_eq!(chunk_embeddings.len(), 1);
    }

    #[tokio::test]
    async fn rolls_back_when_embeddings_mismatch_index_dimension() {
        // The index dimension is deliberately left at the migration default so
        // the 3-dimensional test embeddings are rejected.
        let db = memory_db("test_ns_rollback").await;

        let manifest = build_manifest();
        let result = seed_manifest_into_db(&db, &manifest).await;
        assert!(
            result.is_err(),
            "expected embedding dimension mismatch to fail"
        );

        let text_contents: Vec<TextContent> = select_all(&db, TextContent::table_name()).await;
        let entities: Vec<KnowledgeEntity> = select_all(&db, KnowledgeEntity::table_name()).await;
        let chunks: Vec<TextChunk> = select_all(&db, TextChunk::table_name()).await;
        let relationships: Vec<KnowledgeRelationship> = select_all(&db, "relates_to").await;
        let entity_embeddings: Vec<KnowledgeEntityEmbedding> =
            select_all(&db, KnowledgeEntityEmbedding::table_name()).await;
        let chunk_embeddings: Vec<TextChunkEmbedding> =
            select_all(&db, TextChunkEmbedding::table_name()).await;

        assert!(
            text_contents.is_empty()
                && entities.is_empty()
                && chunks.is_empty()
                && relationships.is_empty()
                && entity_embeddings.is_empty()
                && chunk_embeddings.is_empty(),
            "no rows should be inserted when transaction fails"
        );
    }

    #[test]
    fn window_manifest_trims_questions_and_negatives() {
        let mut manifest = build_manifest();

        // Add extra negatives to simulate multiplier ~4x
        let template = manifest.paragraphs[0].clone();
        manifest.paragraphs.extend((0..8).map(|_| {
            let mut extra = template.clone();
            extra.paragraph_id = Uuid::new_v4().to_string();
            extra.entities.clear();
            extra.relationships.clear();
            extra.chunks.clear();
            extra
        }));
        manifest.metadata.paragraph_count = manifest.paragraphs.len();

        let windowed = window_manifest(&manifest, 0, 1, 4.0).expect("window manifest");
        assert_eq!(windowed.questions.len(), 1);

        // Expect roughly 4x negatives (bounded by available paragraphs)
        assert!(
            windowed.paragraphs.len() <= manifest.paragraphs.len(),
            "windowed paragraphs should never exceed original"
        );

        let positive_set: HashSet<_> = windowed
            .questions
            .iter()
            .map(|q| q.paragraph_id.as_str())
            .collect();
        let positives = windowed
            .paragraphs
            .iter()
            .filter(|p| positive_set.contains(p.paragraph_id.as_str()))
            .count();
        let negatives = windowed.paragraphs.len().saturating_sub(positives);

        assert_eq!(positives, 1);
        assert!(negatives >= 1, "should include some negatives");
    }
}
|