mirror of
https://github.com/perstarkse/minne.git
synced 2026-06-24 19:06:30 +02:00
197 lines
6.2 KiB
Rust
197 lines
6.2 KiB
Rust
#![allow(clippy::arithmetic_side_effects)]
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use common::storage::types::StoredObject;
|
|
|
|
use crate::types::EvaluationCandidate;
|
|
|
|
/// Value written into `RetrievalContextStats::tokenizer`. It records that token counts in this
/// module are a character-based estimate (about four characters per token, rounded up), not
/// the output of the `bert-base-cased` tokenizer that the label says ingestion uses.
const TOKENIZER_LABEL: &str = "estimated (~chars/4; ingestion uses bert-base-cased)";
|
|
|
|
#[allow(clippy::struct_field_names)]
|
|
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
|
|
pub struct RetrievedContextStats {
|
|
pub chunk_count: usize,
|
|
pub char_count: usize,
|
|
pub token_count: usize,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub struct RetrievalContextStats {
|
|
pub tokenizer: String,
|
|
pub queries: usize,
|
|
pub total_chunks: usize,
|
|
pub total_chars: usize,
|
|
pub total_tokens: usize,
|
|
pub avg_chunks_per_query: f64,
|
|
pub avg_chars_per_query: f64,
|
|
pub avg_tokens_per_query: f64,
|
|
pub p50_tokens_per_query: usize,
|
|
pub p95_tokens_per_query: usize,
|
|
pub max_tokens_per_query: usize,
|
|
}
|
|
|
|
pub fn stats_for_candidates(candidates: &[EvaluationCandidate]) -> RetrievedContextStats {
|
|
let mut seen_chunk_ids = std::collections::HashSet::new();
|
|
let mut stats = RetrievedContextStats::default();
|
|
|
|
for candidate in candidates {
|
|
for chunk in &candidate.chunks {
|
|
let chunk_id = chunk.chunk.id().to_string();
|
|
if !seen_chunk_ids.insert(chunk_id) {
|
|
continue;
|
|
}
|
|
let text = chunk.chunk.chunk.as_str();
|
|
stats.chunk_count += 1;
|
|
stats.char_count += text.chars().count();
|
|
stats.token_count += estimate_ingestion_tokens(text);
|
|
}
|
|
}
|
|
|
|
stats
|
|
}
|
|
|
|
#[allow(clippy::cast_precision_loss)]
|
|
pub fn aggregate_context_stats(per_query: &[RetrievedContextStats]) -> RetrievalContextStats {
|
|
let queries = per_query.len();
|
|
if queries == 0 {
|
|
return RetrievalContextStats {
|
|
tokenizer: TOKENIZER_LABEL.to_string(),
|
|
queries: 0,
|
|
total_chunks: 0,
|
|
total_chars: 0,
|
|
total_tokens: 0,
|
|
avg_chunks_per_query: 0.0,
|
|
avg_chars_per_query: 0.0,
|
|
avg_tokens_per_query: 0.0,
|
|
p50_tokens_per_query: 0,
|
|
p95_tokens_per_query: 0,
|
|
max_tokens_per_query: 0,
|
|
};
|
|
}
|
|
|
|
let total_chunks: usize = per_query.iter().map(|stats| stats.chunk_count).sum();
|
|
let total_chars: usize = per_query.iter().map(|stats| stats.char_count).sum();
|
|
let total_tokens: usize = per_query.iter().map(|stats| stats.token_count).sum();
|
|
let mut tokens_per_query: Vec<usize> =
|
|
per_query.iter().map(|stats| stats.token_count).collect();
|
|
tokens_per_query.sort_unstable();
|
|
let max_tokens_per_query = *tokens_per_query.last().unwrap_or(&0);
|
|
|
|
let total_chunks_f = total_chunks as f64;
|
|
let total_chars_f = total_chars as f64;
|
|
let total_tokens_f = total_tokens as f64;
|
|
let queries_f = queries as f64;
|
|
let avg_chunks_per_query = total_chunks_f / queries_f;
|
|
let avg_chars_per_query = total_chars_f / queries_f;
|
|
let avg_tokens_per_query = total_tokens_f / queries_f;
|
|
|
|
RetrievalContextStats {
|
|
tokenizer: TOKENIZER_LABEL.to_string(),
|
|
queries,
|
|
total_chunks,
|
|
total_chars,
|
|
total_tokens,
|
|
avg_chunks_per_query,
|
|
avg_chars_per_query,
|
|
avg_tokens_per_query,
|
|
p50_tokens_per_query: percentile_usize(&tokens_per_query, 0.50),
|
|
p95_tokens_per_query: percentile_usize(&tokens_per_query, 0.95),
|
|
max_tokens_per_query,
|
|
}
|
|
}
|
|
|
|
/// Roughly estimates how many tokens `text` occupies, at four characters per token.
///
/// Characters are Unicode scalar values (`str::chars`), not bytes. A partial group of
/// characters rounds up, so empty text is 0 tokens and any non-empty text is at least 1.
fn estimate_ingestion_tokens(text: &str) -> usize {
    // `0.div_ceil(4)` is 0, so empty input needs no special case.
    text.chars().count().div_ceil(4)
}
|
|
|
|
/// Returns the value at the `fraction` quantile of an ascending-sorted slice.
///
/// The element chosen is at index `round((len - 1) * fraction)`, where `round` sends halves
/// away from zero. `fraction` is clamped to `[0.0, 1.0]` first. An empty slice yields `0`.
/// The slice is not sorted or checked here; callers must pass it already in ascending order.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_sign_loss,
    clippy::cast_possible_truncation,
    clippy::indexing_slicing,
    clippy::arithmetic_side_effects
)]
fn percentile_usize(sorted: &[usize], fraction: f64) -> usize {
    let Some(last_index) = sorted.len().checked_sub(1) else {
        return 0;
    };
    let position = (last_index as f64 * fraction.clamp(0.0, 1.0)).round() as usize;
    // `position.min(last_index)` keeps the index in bounds even for odd float results (e.g. NaN).
    sorted[position.min(last_index)]
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use common::storage::types::text_chunk::TextChunk;
    use retrieval_pipeline::RetrievedChunk;

    #[test]
    fn deduplicates_chunks_when_counting_context() {
        let shared = Arc::new(TextChunk::new(
            "src".into(),
            "hello world".into(),
            "user".into(),
        ));
        let candidates = vec![
            EvaluationCandidate {
                entity_id: "a".into(),
                source_id: "src".into(),
                entity_name: "A".into(),
                entity_description: None,
                entity_category: None,
                score: 1.0,
                chunks: vec![RetrievedChunk {
                    chunk: Arc::clone(&shared),
                    score: 1.0,
                }],
            },
            EvaluationCandidate {
                entity_id: "b".into(),
                source_id: "src".into(),
                entity_name: "B".into(),
                entity_description: None,
                entity_category: None,
                score: 0.9,
                chunks: vec![RetrievedChunk {
                    chunk: shared,
                    score: 0.9,
                }],
            },
        ];
        let stats = stats_for_candidates(&candidates);
        assert_eq!(stats.chunk_count, 1);
        assert_eq!(stats.char_count, "hello world".chars().count());
        assert_eq!(stats.token_count, 3);
    }

    #[test]
    fn aggregates_per_query_token_totals() {
        let per_query = vec![
            RetrievedContextStats {
                chunk_count: 2,
                char_count: 100,
                token_count: 40,
            },
            RetrievedContextStats {
                chunk_count: 5,
                char_count: 250,
                token_count: 100,
            },
        ];
        let aggregate = aggregate_context_stats(&per_query);
        assert_eq!(aggregate.tokenizer, TOKENIZER_LABEL);
        assert_eq!(aggregate.queries, 2);
        assert_eq!(aggregate.total_chunks, 7);
        assert_eq!(aggregate.total_chars, 350);
        assert_eq!(aggregate.total_tokens, 140);
        assert_eq!(aggregate.max_tokens_per_query, 100);
        // With two queries the rounded rank for both p50 (0.5) and p95 (0.95) is index 1.
        assert_eq!(aggregate.p50_tokens_per_query, 100);
        assert_eq!(aggregate.p95_tokens_per_query, 100);
        assert!((aggregate.avg_chunks_per_query - 3.5).abs() < f64::EPSILON);
        assert!((aggregate.avg_chars_per_query - 175.0).abs() < f64::EPSILON);
        assert!((aggregate.avg_tokens_per_query - 70.0).abs() < f64::EPSILON);
    }

    #[test]
    fn aggregating_no_queries_yields_zeroed_summary() {
        let aggregate = aggregate_context_stats(&[]);
        assert_eq!(aggregate.tokenizer, TOKENIZER_LABEL);
        assert_eq!(aggregate.queries, 0);
        assert_eq!(aggregate.total_chunks, 0);
        assert_eq!(aggregate.total_chars, 0);
        assert_eq!(aggregate.total_tokens, 0);
        assert_eq!(aggregate.p50_tokens_per_query, 0);
        assert_eq!(aggregate.p95_tokens_per_query, 0);
        assert_eq!(aggregate.max_tokens_per_query, 0);
        // Averages must be 0.0, not NaN from dividing by zero queries.
        assert!(aggregate.avg_chunks_per_query.abs() < f64::EPSILON);
        assert!(aggregate.avg_chars_per_query.abs() < f64::EPSILON);
        assert!(aggregate.avg_tokens_per_query.abs() < f64::EPSILON);
    }

    #[test]
    fn percentile_uses_rounded_rank_and_clamps_fraction() {
        let sorted: [usize; 5] = [10, 20, 30, 40, 50];
        assert_eq!(percentile_usize(&sorted, 0.0), 10);
        assert_eq!(percentile_usize(&sorted, 0.5), 30);
        assert_eq!(percentile_usize(&sorted, 0.95), 50);
        assert_eq!(percentile_usize(&sorted, -1.0), 10);
        assert_eq!(percentile_usize(&sorted, 2.0), 50);
        assert_eq!(percentile_usize(&[], 0.5), 0);
    }

    #[test]
    fn estimates_tokens_from_chars_rounding_up() {
        assert_eq!(estimate_ingestion_tokens(""), 0);
        assert_eq!(estimate_ingestion_tokens("abcd"), 1);
        assert_eq!(estimate_ingestion_tokens("abcde"), 2);
        // Counts chars, not bytes: each "é" is two bytes but one char.
        assert_eq!(estimate_ingestion_tokens("éééé"), 1);
    }
}
|