use crate::connection_or_tx::ConnectionOrTx; use crate::db_context::DbContext; use crate::error::Error::GenericError; use crate::error::Result; use crate::util::ModelPayload; use r2d2::Pool; use r2d2_sqlite::SqliteConnectionManager; use rusqlite::TransactionBehavior; use std::sync::{Arc, Mutex}; use tauri::{Manager, Runtime, State}; use tokio::sync::mpsc; pub trait QueryManagerExt<'a, R> { fn db_manager(&'a self) -> State<'a, QueryManager>; fn db(&'a self) -> DbContext<'a>; fn with_db(&'a self, func: F) -> T where F: FnOnce(&DbContext) -> T; fn with_tx(&'a self, func: F) -> Result where F: FnOnce(&DbContext) -> Result; } impl<'a, R: Runtime, M: Manager> QueryManagerExt<'a, R> for M { fn db_manager(&'a self) -> State<'a, QueryManager> { self.state::() } fn db(&'a self) -> DbContext<'a> { let qm = self.state::(); qm.inner().connect() } fn with_db(&'a self, func: F) -> T where F: FnOnce(&DbContext) -> T, { let qm = self.state::(); qm.inner().with_conn(func) } fn with_tx(&'a self, func: F) -> Result where F: FnOnce(&DbContext) -> Result, { let qm = self.state::(); qm.inner().with_tx(func) } } #[derive(Debug, Clone)] pub struct QueryManager { pool: Arc>>, events_tx: mpsc::Sender, } impl QueryManager { pub(crate) fn new( pool: Pool, events_tx: mpsc::Sender, ) -> Self { QueryManager { pool: Arc::new(Mutex::new(pool)), events_tx, } } pub fn connect(&self) -> DbContext { let conn = self .pool .lock() .expect("Failed to gain lock on DB") .get() .expect("Failed to get a new DB connection from the pool"); DbContext { events_tx: self.events_tx.clone(), conn: ConnectionOrTx::Connection(conn), } } pub fn with_conn(&self, func: F) -> T where F: FnOnce(&DbContext) -> T, { let conn = self .pool .lock() .expect("Failed to gain lock on DB for transaction") .get() .expect("Failed to get new DB connection from the pool"); let db_context = DbContext { events_tx: self.events_tx.clone(), conn: ConnectionOrTx::Connection(conn), }; func(&db_context) } pub fn with_tx( &self, func: impl FnOnce(&DbContext) -> std::result::Result, ) -> std::result::Result where E: From, { let mut conn = self .pool .lock() .expect("Failed to gain lock on DB for transaction") .get() .expect("Failed to get new DB connection from the pool"); let tx = conn .transaction_with_behavior(TransactionBehavior::Immediate) .expect("Failed to start DB transaction"); let db_context = DbContext { events_tx: self.events_tx.clone(), conn: ConnectionOrTx::Transaction(&tx), }; match func(&db_context) { Ok(val) => { tx.commit() .map_err(|e| GenericError(format!("Failed to commit transaction {e:?}")))?; Ok(val) } Err(e) => { tx.rollback() .map_err(|e| GenericError(format!("Failed to rollback transaction {e:?}")))?; Err(e) } } } }