feat: d3js instead of plotly, improved graph visualisation

This commit is contained in:
Per Stark
2025-09-06 20:44:41 +02:00
parent fdf29bb735
commit 153efd1a98
10 changed files with 493 additions and 311 deletions

View File

@@ -1,16 +1,11 @@
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use axum::{
extract::{Path, Query, State},
response::IntoResponse,
Form,
Form, Json,
};
use axum_htmx::{HxBoosted, HxRequest};
use plotly::{
common::{Line, Marker, Mode},
layout::{Axis, LayoutScene},
Layout, Plot, Scatter3D,
};
use serde::{Deserialize, Serialize};
use common::storage::types::{
@@ -39,7 +34,6 @@ pub struct KnowledgeBaseData {
entities: Vec<KnowledgeEntity>,
relationships: Vec<KnowledgeRelationship>,
user: User,
plot_html: String,
entity_types: Vec<String>,
content_categories: Vec<String>,
selected_entity_type: Option<String>,
@@ -54,12 +48,9 @@ pub async fn show_knowledge_page(
HxBoosted(is_boosted): HxBoosted,
Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> {
// Normalize filters
params.entity_type = params.entity_type.take().filter(|s| !s.trim().is_empty());
params.content_category = params
.content_category
.take()
.filter(|s| !s.trim().is_empty());
// Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take());
// Load relevant data
let entity_types = User::get_entity_types(&user.id, &state.db).await?;
@@ -77,14 +68,12 @@ pub async fn show_knowledge_page(
};
let relationships = User::get_knowledge_relationships(&user.id, &state.db).await?;
let plot_html = get_plot_html(&entities, &relationships)?;
let conversation_archive = User::get_user_conversations(&user.id, &state.db).await?;
let kb_data = KnowledgeBaseData {
entities,
relationships,
user,
plot_html,
entity_types,
content_categories,
selected_entity_type: params.entity_type.clone(),
@@ -111,170 +100,94 @@ pub async fn show_knowledge_page(
}
}
fn get_plot_html(
entities: &[KnowledgeEntity],
relationships: &[KnowledgeRelationship],
) -> Result<String, HtmlError> {
if entities.is_empty() {
return Ok(String::new());
#[derive(Serialize)]
pub struct GraphNode {
pub id: String,
pub name: String,
pub entity_type: String,
pub degree: usize,
}
#[derive(Serialize)]
pub struct GraphLink {
pub source: String,
pub target: String,
pub relationship_type: String,
}
#[derive(Serialize)]
pub struct GraphData {
pub nodes: Vec<GraphNode>,
pub links: Vec<GraphLink>,
}
pub async fn get_knowledge_graph_json(
State(state): State<HtmlState>,
RequireUser(user): RequireUser,
Query(mut params): Query<FilterParams>,
) -> Result<impl IntoResponse, HtmlError> {
// Normalize filters: treat empty or "none" as no filter
params.entity_type = normalize_filter(params.entity_type.take());
params.content_category = normalize_filter(params.content_category.take());
// Load entities based on filters
let entities: Vec<KnowledgeEntity> = match &params.content_category {
Some(cat) => {
User::get_knowledge_entities_by_content_category(&user.id, cat, &state.db).await?
}
None => match &params.entity_type {
Some(etype) => User::get_knowledge_entities_by_type(&user.id, etype, &state.db).await?,
None => User::get_knowledge_entities(&user.id, &state.db).await?,
},
};
// All relationships for user, then filter to those whose endpoints are in the set
let relationships: Vec<KnowledgeRelationship> =
User::get_knowledge_relationships(&user.id, &state.db).await?;
let entity_ids: std::collections::HashSet<String> =
entities.iter().map(|e| e.id.clone()).collect();
let mut degree_count: HashMap<String, usize> = HashMap::new();
let mut links: Vec<GraphLink> = Vec::new();
for rel in relationships.iter() {
if entity_ids.contains(&rel.in_) && entity_ids.contains(&rel.out) {
// undirected counting for degree
*degree_count.entry(rel.in_.clone()).or_insert(0) += 1;
*degree_count.entry(rel.out.clone()).or_insert(0) += 1;
links.push(GraphLink {
source: rel.out.clone(),
target: rel.in_.clone(),
relationship_type: rel.metadata.relationship_type.clone(),
});
}
}
let id_to_idx: HashMap<_, _> = entities
.iter()
.enumerate()
.map(|(i, e)| (e.id.clone(), i))
let nodes: Vec<GraphNode> = entities
.into_iter()
.map(|e| GraphNode {
id: e.id.clone(),
name: e.name.clone(),
entity_type: format!("{:?}", e.entity_type),
degree: *degree_count.get(&e.id).unwrap_or(&0),
})
.collect();
// Build adjacency list
let mut graph: Vec<Vec<usize>> = vec![Vec::new(); entities.len()];
for rel in relationships {
if let (Some(&from_idx), Some(&to_idx)) = (id_to_idx.get(&rel.out), id_to_idx.get(&rel.in_))
{
graph[from_idx].push(to_idx);
graph[to_idx].push(from_idx);
}
}
// Find clusters (connected components)
let mut visited = vec![false; entities.len()];
let mut clusters: Vec<Vec<usize>> = Vec::new();
for i in 0..entities.len() {
if !visited[i] {
let mut queue = VecDeque::new();
let mut cluster = Vec::new();
queue.push_back(i);
visited[i] = true;
while let Some(node) = queue.pop_front() {
cluster.push(node);
for &nbr in &graph[node] {
if !visited[nbr] {
visited[nbr] = true;
queue.push_back(nbr);
}
}
Ok(Json(GraphData { nodes, links }))
}
// Normalize filter parameters: convert empty strings or "none" (case-insensitive) to None
fn normalize_filter(input: Option<String>) -> Option<String> {
match input {
None => None,
Some(s) => {
let trimmed = s.trim();
if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("none") {
None
} else {
Some(trimmed.to_string())
}
clusters.push(cluster);
}
}
// Layout params
let cluster_spacing = 20.0; // Distance between clusters
let node_spacing = 3.0; // Distance between nodes within cluster
// Arrange clusters on a Fibonacci sphere (uniform 3D positioning on unit sphere)
let cluster_count = clusters.len();
let golden_angle = std::f64::consts::PI * (3.0 - (5.0f64).sqrt());
// Will hold final positions of nodes: (x,y,z)
let mut nodes_pos = vec![(0.0f64, 0.0f64, 0.0f64); entities.len()];
for (i, cluster) in clusters.iter().enumerate() {
// Position cluster center on unit sphere scaled by cluster_spacing
let theta = golden_angle * i as f64;
let z = 1.0 - (2.0 * i as f64 + 1.0) / cluster_count as f64;
let radius = (1.0 - z * z).sqrt();
let cluster_center = (
radius * theta.cos() * cluster_spacing,
radius * theta.sin() * cluster_spacing,
z * cluster_spacing,
);
// Layout nodes within cluster as small 3D grid (cube)
// Calculate cube root to determine grid side length
let cluster_size = cluster.len();
let side_len = (cluster_size as f64).cbrt().ceil() as usize;
for (pos_in_cluster, &node_idx) in cluster.iter().enumerate() {
let x_in_cluster = (pos_in_cluster % side_len) as f64;
let y_in_cluster = ((pos_in_cluster / side_len) % side_len) as f64;
let z_in_cluster = (pos_in_cluster / (side_len * side_len)) as f64;
nodes_pos[node_idx] = (
cluster_center.0 + x_in_cluster * node_spacing,
cluster_center.1 + y_in_cluster * node_spacing,
cluster_center.2 + z_in_cluster * node_spacing,
);
}
}
let (node_x, node_y, node_z): (Vec<_>, Vec<_>, Vec<_>) = nodes_pos.iter().cloned().unzip3();
// Nodes trace
let nodes_trace = Scatter3D::new(node_x, node_y, node_z)
.mode(Mode::Markers)
.marker(Marker::new().size(8).color("#1f77b4"))
.text_array(
entities
.iter()
.map(|e| e.description.clone())
.collect::<Vec<_>>(),
)
.hover_template("Entity: %{text}<extra></extra>");
// Edges traces
let mut plot = Plot::new();
for rel in relationships {
if let (Some(&from_idx), Some(&to_idx)) = (id_to_idx.get(&rel.out), id_to_idx.get(&rel.in_))
{
let edge_x = vec![nodes_pos[from_idx].0, nodes_pos[to_idx].0];
let edge_y = vec![nodes_pos[from_idx].1, nodes_pos[to_idx].1];
let edge_z = vec![nodes_pos[from_idx].2, nodes_pos[to_idx].2];
let edge_trace = Scatter3D::new(edge_x, edge_y, edge_z)
.mode(Mode::Lines)
.line(Line::new().color("#888").width(2.0))
.hover_template(format!(
"Relationship: {}<extra></extra>",
rel.metadata.relationship_type
))
.show_legend(false);
plot.add_trace(edge_trace);
}
}
plot.add_trace(nodes_trace);
// Layout scene configuration
let layout = Layout::new()
.scene(
LayoutScene::new()
.x_axis(Axis::new().visible(false))
.y_axis(Axis::new().visible(false))
.z_axis(Axis::new().visible(false))
.camera(
plotly::layout::Camera::new()
.projection(plotly::layout::ProjectionType::Perspective.into())
.eye((2.0, 2.0, 2.0).into()),
),
)
.show_legend(false)
.paper_background_color("rgba(255,255,255,0)")
.plot_background_color("rgba(255,255,255,0)");
plot.set_layout(layout);
Ok(plot.to_html())
}
// Small utility to unzip tuple3 vectors from iterators
trait Unzip3<A, B, C> {
fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>);
}
impl<I, A, B, C> Unzip3<A, B, C> for I
where
I: Iterator<Item = (A, B, C)>,
{
fn unzip3(self) -> (Vec<A>, Vec<B>, Vec<C>) {
let (mut va, mut vb, mut vc) = (Vec::new(), Vec::new(), Vec::new());
for (a, b, c) in self {
va.push(a);
vb.push(b);
vc.push(c);
}
(va, vb, vc)
}
}
pub async fn show_edit_knowledge_entity_form(

View File

@@ -6,8 +6,9 @@ use axum::{
Router,
};
use handlers::{
delete_knowledge_entity, delete_knowledge_relationship, patch_knowledge_entity,
save_knowledge_relationship, show_edit_knowledge_entity_form, show_knowledge_page,
delete_knowledge_entity, delete_knowledge_relationship, get_knowledge_graph_json,
patch_knowledge_entity, save_knowledge_relationship, show_edit_knowledge_entity_form,
show_knowledge_page,
};
use crate::html_state::HtmlState;
@@ -19,6 +20,7 @@ where
{
Router::new()
.route("/knowledge", get(show_knowledge_page))
.route("/knowledge/graph.json", get(get_knowledge_graph_json))
.route(
"/knowledge-entity/{id}",
get(show_edit_knowledge_entity_form)