diff --git a/src/server/routes/query.rs b/src/server/routes/query.rs
index cd1705b..b540db4 100644
--- a/src/server/routes/query.rs
+++ b/src/server/routes/query.rs
@@ -5,7 +5,7 @@ use async_openai::types::{
     ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
     CreateChatCompletionRequestArgs, ResponseFormat, ResponseFormatJsonSchema,
 };
-use axum::{response::IntoResponse, Extension, Json};
+use axum::{http::StatusCode, response::IntoResponse, Extension, Json};
 use serde::Deserialize;
 use serde_json::json;
 use std::sync::Arc;
@@ -16,6 +16,22 @@ pub struct QueryInput {
     query: String,
 }
 
+#[derive(Debug, Deserialize)]
+struct Reference {
+    reference: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct LLMResponseFormat {
+    answer: String,
+    references: Vec<Reference>,
+}
+// impl IntoResponse for LLMResponseFormat {
+//     fn into_response(self) -> axum::response::Response {
+//         (StatusCode::OK, Json(self))
+//     }
+// }
+
 pub async fn query_handler(
     Extension(db_client): Extension<Arc<DbClient>>,
     Json(query): Json<QueryInput>,
@@ -127,12 +143,16 @@ pub async fn query_handler(
         .first()
         .and_then(|choice| choice.message.content.as_ref())
         .ok_or(ApiError::QueryError(
-            "No content found in LLM response".to_string(),
+            "No content found in LLM response".into(),
         ))?;
 
     info!("{:?}", answer);
 
-    // info!("{:#?}", entities_json);
+    let parsed: LLMResponseFormat = serde_json::from_str(answer).map_err(|e| {
+        ApiError::QueryError(format!("Failed to parse LLM response into analysis: {}", e))
+    })?;
 
-    Ok(answer.clone().into_response())
+    info!("{:?}", parsed);
+
+    Ok(parsed.answer.into_response())
 }