From 30bcf5dcc22e1bb1fb983a8d2887edc39404e6df Mon Sep 17 00:00:00 2001
From: Fabian-Lars
Date: Tue, 1 Oct 2024 14:47:08 +0200
Subject: [PATCH] refactor(sql): Allow multiple drivers at the same time (#1838)

* refactor(sql): Allow multiple drivers at the same time

* fmt

* remove default feature comment [skip ci]

* what was that doing there [skip ci]

* disable public methods for now
---
 .changes/feat-multiple-sql-backends.md |   5 +
 .github/workflows/lint-rust.yml        |  10 +-
 .github/workflows/test-rust.yml        |  14 +-
 plugins/localhost/src/lib.rs           |   2 +-
 plugins/sql/Cargo.toml                 |   1 +
 plugins/sql/src/commands.rs            |  82 ++++++
 plugins/sql/src/decode/mod.rs          |  15 +-
 plugins/sql/src/error.rs               |  28 ++
 plugins/sql/src/lib.rs                 | 180 +++++++++++--
 plugins/sql/src/plugin.rs              | 343 ------------------------
 plugins/sql/src/wrapper.rs             | 328 +++++++++++++++++++++++
 11 files changed, 615 insertions(+), 393 deletions(-)
 create mode 100644 .changes/feat-multiple-sql-backends.md
 create mode 100644 plugins/sql/src/commands.rs
 create mode 100644 plugins/sql/src/error.rs
 delete mode 100644 plugins/sql/src/plugin.rs
 create mode 100644 plugins/sql/src/wrapper.rs

diff --git a/.changes/feat-multiple-sql-backends.md b/.changes/feat-multiple-sql-backends.md
new file mode 100644
index 00000000..65b8fe86
--- /dev/null
+++ b/.changes/feat-multiple-sql-backends.md
@@ -0,0 +1,5 @@
+---
+sql: patch
+---
+
+It is now possible to enable multiple SQL backends at the same time. There will be no compile error anymore if no backends are enabled!

diff --git a/.github/workflows/lint-rust.yml b/.github/workflows/lint-rust.yml
index 74d9d766..39cc37fe 100644
--- a/.github/workflows/lint-rust.yml
+++ b/.github/workflows/lint-rust.yml
@@ -148,13 +148,7 @@ jobs:
       - uses: Swatinem/rust-cache@v2
 
       - name: clippy ${{ matrix.package }}
-        if: matrix.package != 'tauri-plugin-sql'
         run: cargo clippy --package ${{ matrix.package }} --all-targets -- -D warnings
 
-      - name: clippy ${{ matrix.package }} mysql
-        if: matrix.package == 'tauri-plugin-sql'
-        run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features mysql -- -D warnings
-
-      - name: clippy ${{ matrix.package }} postgres
-        if: matrix.package == 'tauri-plugin-sql'
-        run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features postgres -- -D warnings
+      - name: clippy ${{ matrix.package }} --all-features
+        run: cargo clippy --package ${{ matrix.package }} --all-targets --all-features -- -D warnings

diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml
index 75b1e55f..34af5e2c 100644
--- a/.github/workflows/test-rust.yml
+++ b/.github/workflows/test-rust.yml
@@ -215,21 +215,9 @@ jobs:
         run: cargo +stable install cross --git https://github.com/cross-rs/cross
 
       - name: test ${{ matrix.package }}
-        if: matrix.package != 'tauri-plugin-sql' && matrix.package != 'tauri-plugin-http'
+        if: matrix.package != 'tauri-plugin-http'
         run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --all-features
 
       - name: test ${{ matrix.package }}
         if: matrix.package == 'tauri-plugin-http'
         run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets
-
-      - name: test ${{ matrix.package }} sqlite
-        if: matrix.package == 'tauri-plugin-sql'
-        run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features sqlite
-
-      - name: test ${{ matrix.package }} mysql
-        if: matrix.package == 'tauri-plugin-sql'
-        run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features mysql
-
-      - name: test ${{ matrix.package }} postgres
-        if: matrix.package == 'tauri-plugin-sql'
-        run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features postgres

diff --git a/plugins/localhost/src/lib.rs b/plugins/localhost/src/lib.rs
index f5f99fe4..a0c4c794 100644
--- a/plugins/localhost/src/lib.rs
+++ b/plugins/localhost/src/lib.rs
@@ -74,7 +74,7 @@ impl Builder {
         let asset_resolver = app.asset_resolver();
         std::thread::spawn(move || {
             let server =
-                Server::http(&format!("localhost:{port}")).expect("Unable to spawn server");
+                Server::http(format!("localhost:{port}")).expect("Unable to spawn server");
             for req in server.incoming_requests() {
                 let path = req
                     .url()

diff --git a/plugins/sql/Cargo.toml b/plugins/sql/Cargo.toml
index bfcfc99e..4f4db76f 100644
--- a/plugins/sql/Cargo.toml
+++ b/plugins/sql/Cargo.toml
@@ -40,3 +40,4 @@ indexmap = { version = "2", features = ["serde"] }
 sqlite = ["sqlx/sqlite", "sqlx/runtime-tokio"]
 mysql = ["sqlx/mysql", "sqlx/runtime-tokio-rustls"]
 postgres = ["sqlx/postgres", "sqlx/runtime-tokio-rustls"]
+# TODO: bundled-cipher etc

diff --git a/plugins/sql/src/commands.rs b/plugins/sql/src/commands.rs
new file mode 100644
index 00000000..8cd90e9c
--- /dev/null
+++ b/plugins/sql/src/commands.rs
@@ -0,0 +1,82 @@
+// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+use indexmap::IndexMap;
+use serde_json::Value as JsonValue;
+use sqlx::migrate::Migrator;
+use tauri::{command, AppHandle, Runtime, State};
+
+use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations};
+
+#[command]
+pub(crate) async fn load<R: Runtime>(
+    app: AppHandle<R>,
+    db_instances: State<'_, DbInstances>,
+    migrations: State<'_, Migrations>,
+    db: String,
+) -> Result<String, crate::Error> {
+    let pool = DbPool::connect(&db, &app).await?;
+
+    if let Some(migrations) = migrations.0.lock().await.remove(&db) {
+        let migrator = Migrator::new(migrations).await?;
+        pool.migrate(&migrator).await?;
+    }
+
+    db_instances.0.lock().await.insert(db.clone(), pool);
+
+    Ok(db)
+}
+
+/// Allows the database connection(s) to be closed; if no database
+/// name is passed in then _all_ database connection pools will be
+/// shut down.
+#[command]
+pub(crate) async fn close(
+    db_instances: State<'_, DbInstances>,
+    db: Option<String>,
+) -> Result<bool, crate::Error> {
+    let mut instances = db_instances.0.lock().await;
+
+    let pools = if let Some(db) = db {
+        vec![db]
+    } else {
+        instances.keys().cloned().collect()
+    };
+
+    for pool in pools {
+        let db = instances
+            .get_mut(&pool)
+            .ok_or(Error::DatabaseNotLoaded(pool))?;
+        db.close().await;
+    }
+
+    Ok(true)
+}
+
+/// Execute a command against the database
+#[command]
+pub(crate) async fn execute(
+    db_instances: State<'_, DbInstances>,
+    db: String,
+    query: String,
+    values: Vec<JsonValue>,
+) -> Result<(u64, LastInsertId), crate::Error> {
+    let mut instances = db_instances.0.lock().await;
+
+    let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
+    db.execute(query, values).await
+}
+
+#[command]
+pub(crate) async fn select(
+    db_instances: State<'_, DbInstances>,
+    db: String,
+    query: String,
+    values: Vec<JsonValue>,
+) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
+    let mut instances = db_instances.0.lock().await;
+
+    let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
+    db.select(query, values).await
+}
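Note on the return shape of `execute`: the `(u64, LastInsertId)` tuple serializes as `[rows_affected, last_insert_id]`, and because `LastInsertId` (introduced in `lib.rs` below) is an untagged serde enum, the second element reaches the frontend as a bare number for SQLite/MySQL and as `null` for Postgres. A minimal sketch of that behavior, using a standalone stand-in for the `pub(crate)` enum (`serde` with the derive feature and `serde_json` assumed as dependencies):

```rust
use serde::Serialize;

// Stand-in mirroring the plugin's pub(crate) LastInsertId, for illustration only.
#[derive(Serialize)]
#[serde(untagged)]
enum LastInsertId {
    Sqlite(i64),
    Postgres(()),
}

fn main() {
    // Untagged newtype variants serialize as their inner value...
    assert_eq!(
        serde_json::to_string(&LastInsertId::Sqlite(42)).unwrap(),
        "42"
    );
    // ...and the unit inner type of the Postgres variant becomes `null`.
    assert_eq!(
        serde_json::to_string(&LastInsertId::Postgres(())).unwrap(),
        "null"
    );
}
```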
diff --git a/plugins/sql/src/decode/mod.rs b/plugins/sql/src/decode/mod.rs
index 50fb3c78..0a2d2cdd 100644
--- a/plugins/sql/src/decode/mod.rs
+++ b/plugins/sql/src/decode/mod.rs
@@ -3,17 +3,8 @@
 // SPDX-License-Identifier: MIT
 
 #[cfg(feature = "mysql")]
-mod mysql;
+pub(crate) mod mysql;
 #[cfg(feature = "postgres")]
-mod postgres;
+pub(crate) mod postgres;
 #[cfg(feature = "sqlite")]
-mod sqlite;
-
-#[cfg(feature = "mysql")]
-pub(crate) use mysql::to_json;
-
-#[cfg(feature = "postgres")]
-pub(crate) use postgres::to_json;
-
-#[cfg(feature = "sqlite")]
-pub(crate) use sqlite::to_json;
+pub(crate) mod sqlite;

diff --git a/plugins/sql/src/error.rs b/plugins/sql/src/error.rs
new file mode 100644
index 00000000..5ac845b8
--- /dev/null
+++ b/plugins/sql/src/error.rs
@@ -0,0 +1,28 @@
+// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+use serde::{Serialize, Serializer};
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error(transparent)]
+    Sql(#[from] sqlx::Error),
+    #[error(transparent)]
+    Migration(#[from] sqlx::migrate::MigrateError),
+    #[error("invalid connection url: {0}")]
+    InvalidDbUrl(String),
+    #[error("database {0} not loaded")]
+    DatabaseNotLoaded(String),
+    #[error("unsupported datatype: {0}")]
+    UnsupportedDatatype(String),
+}
+
+impl Serialize for Error {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(self.to_string().as_ref())
+    }
+}
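These errors surface in the webview as plain strings: the manual `Serialize` impl writes the `Display` output that thiserror derives from each `#[error(...)]` attribute. A small illustration with a standalone copy of two variants (not the plugin's actual type; `thiserror` assumed as a dependency):

```rust
// Illustrative stand-in for the plugin's error type.
#[derive(Debug, thiserror::Error)]
enum Error {
    #[error("invalid connection url: {0}")]
    InvalidDbUrl(String),
    #[error("database {0} not loaded")]
    DatabaseNotLoaded(String),
}

fn main() {
    // thiserror generates Display from the #[error(...)] attributes,
    // which is exactly the string the frontend receives.
    assert_eq!(
        Error::InvalidDbUrl("foo".into()).to_string(),
        "invalid connection url: foo"
    );
    assert_eq!(
        Error::DatabaseNotLoaded("sqlite:test.db".into()).to_string(),
        "database sqlite:test.db not loaded"
    );
}
```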
diff --git a/plugins/sql/src/lib.rs b/plugins/sql/src/lib.rs
index f25ede21..ec9362bf 100644
--- a/plugins/sql/src/lib.rs
+++ b/plugins/sql/src/lib.rs
@@ -11,20 +11,168 @@
     html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
 )]
 
-#[cfg(any(
-    all(feature = "sqlite", feature = "mysql"),
-    all(feature = "sqlite", feature = "postgres"),
-    all(feature = "mysql", feature = "postgres")
-))]
-compile_error!(
-    "Only one database driver can be enabled. Set the feature flag for the driver of your choice."
-);
-
-#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
-compile_error!(
-    "Database driver not defined. Please set the feature flag for the driver of your choice."
-);
-
+mod commands;
 mod decode;
-mod plugin;
-pub use plugin::*;
+mod error;
+mod wrapper;
+
+pub use error::Error;
+pub use wrapper::DbPool;
+
+use futures_core::future::BoxFuture;
+use serde::{Deserialize, Serialize};
+use sqlx::{
+    error::BoxDynError,
+    migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator},
+};
+use tauri::{
+    plugin::{Builder as PluginBuilder, TauriPlugin},
+    Manager, RunEvent, Runtime,
+};
+use tokio::sync::Mutex;
+
+use std::collections::HashMap;
+
+#[derive(Default)]
+pub struct DbInstances(pub Mutex<HashMap<String, DbPool>>);
+
+#[derive(Serialize)]
+#[serde(untagged)]
+pub(crate) enum LastInsertId {
+    #[cfg(feature = "sqlite")]
+    Sqlite(i64),
+    #[cfg(feature = "mysql")]
+    MySql(u64),
+    #[cfg(feature = "postgres")]
+    Postgres(()),
+    #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
+    None,
+}
+
+struct Migrations(Mutex<HashMap<String, MigrationList>>);
+
+#[derive(Default, Clone, Deserialize)]
+pub struct PluginConfig {
+    #[serde(default)]
+    preload: Vec<String>,
+}
+
+#[derive(Debug)]
+pub enum MigrationKind {
+    Up,
+    Down,
+}
+
+impl From<MigrationKind> for MigrationType {
+    fn from(kind: MigrationKind) -> Self {
+        match kind {
+            MigrationKind::Up => Self::ReversibleUp,
+            MigrationKind::Down => Self::ReversibleDown,
+        }
+    }
+}
+
+/// A migration definition.
+#[derive(Debug)]
+pub struct Migration {
+    pub version: i64,
+    pub description: &'static str,
+    pub sql: &'static str,
+    pub kind: MigrationKind,
+}
+
+#[derive(Debug)]
+struct MigrationList(Vec<Migration>);
+
+impl MigrationSource<'static> for MigrationList {
+    fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
+        Box::pin(async move {
+            let mut migrations = Vec::new();
+            for migration in self.0 {
+                if matches!(migration.kind, MigrationKind::Up) {
+                    migrations.push(SqlxMigration::new(
+                        migration.version,
+                        migration.description.into(),
+                        migration.kind.into(),
+                        migration.sql.into(),
+                        false,
+                    ));
+                }
+            }
+            Ok(migrations)
+        })
+    }
+}
+
+/// Tauri SQL plugin builder.
+#[derive(Default)]
+pub struct Builder {
+    migrations: Option<HashMap<String, MigrationList>>,
+}
+
+impl Builder {
+    pub fn new() -> Self {
+        #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
+        eprintln!("No sql driver enabled. Please set at least one of the \"sqlite\", \"mysql\", \"postgres\" feature flags.");
+
+        Self::default()
+    }
+
+    /// Add migrations to a database.
+    #[must_use]
+    pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
+        self.migrations
+            .get_or_insert(Default::default())
+            .insert(db_url.to_string(), MigrationList(migrations));
+        self
+    }
+
+    pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
+        PluginBuilder::<R, Option<PluginConfig>>::new("sql")
+            .invoke_handler(tauri::generate_handler![
+                commands::load,
+                commands::execute,
+                commands::select,
+                commands::close
+            ])
+            .setup(|app, api| {
+                let config = api.config().clone().unwrap_or_default();
+
+                tauri::async_runtime::block_on(async move {
+                    let instances = DbInstances::default();
+                    let mut lock = instances.0.lock().await;
+
+                    for db in config.preload {
+                        let pool = DbPool::connect(&db, app).await?;
+
+                        if let Some(migrations) = self.migrations.as_mut().unwrap().remove(&db) {
+                            let migrator = Migrator::new(migrations).await?;
+                            pool.migrate(&migrator).await?;
+                        }
+
+                        lock.insert(db, pool);
+                    }
+                    drop(lock);
+
+                    app.manage(instances);
+                    app.manage(Migrations(Mutex::new(
+                        self.migrations.take().unwrap_or_default(),
+                    )));
+
+                    Ok(())
+                })
+            })
+            .on_event(|app, event| {
+                if let RunEvent::Exit = event {
+                    tauri::async_runtime::block_on(async move {
+                        let instances = &*app.state::<DbInstances>();
+                        let instances = instances.0.lock().await;
+                        for value in instances.values() {
+                            value.close().await;
+                        }
+                    });
+                }
+            })
+            .build()
+    }
+}
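Taken together, `Builder`, `Migration`, and `MigrationKind` are the public surface an app touches. A minimal sketch of wiring the plugin into a Tauri app with one migration (the `main.rs` wiring and the `sqlite:todos.db` name are illustrative, not part of this patch):

```rust
use tauri_plugin_sql::{Builder, Migration, MigrationKind};

fn main() {
    // Migrations are keyed by the same connection string the frontend
    // later passes to the `load` command.
    let migrations = vec![Migration {
        version: 1,
        description: "create_todos_table",
        sql: "CREATE TABLE todos (id INTEGER PRIMARY KEY, title TEXT NOT NULL);",
        kind: MigrationKind::Up,
    }];

    tauri::Builder::default()
        .plugin(
            Builder::new()
                .add_migrations("sqlite:todos.db", migrations)
                .build(),
        )
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}
```

Databases listed in the plugin's `preload` config (the `PluginConfig` deserialized in `setup` above) are connected and migrated during startup instead of on the first `load` call.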
diff --git a/plugins/sql/src/plugin.rs b/plugins/sql/src/plugin.rs
deleted file mode 100644
index 63f8e183..00000000
--- a/plugins/sql/src/plugin.rs
+++ /dev/null
@@ -1,343 +0,0 @@
-// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-License-Identifier: MIT
-
-use futures_core::future::BoxFuture;
-use serde::{ser::Serializer, Deserialize, Serialize};
-use serde_json::Value as JsonValue;
-use sqlx::{
-    error::BoxDynError,
-    migrate::{
-        MigrateDatabase, Migration as SqlxMigration, MigrationSource, MigrationType, Migrator,
-    },
-    Column, Pool, Row,
-};
-use tauri::{
-    command,
-    plugin::{Builder as PluginBuilder, TauriPlugin},
-    AppHandle, Manager, RunEvent, Runtime, State,
-};
-use tokio::sync::Mutex;
-
-use indexmap::IndexMap;
-use std::collections::HashMap;
-
-#[cfg(feature = "sqlite")]
-use std::{fs::create_dir_all, path::PathBuf};
-
-#[cfg(feature = "sqlite")]
-type Db = sqlx::sqlite::Sqlite;
-#[cfg(feature = "mysql")]
-type Db = sqlx::mysql::MySql;
-#[cfg(feature = "postgres")]
-type Db = sqlx::postgres::Postgres;
-
-#[cfg(feature = "sqlite")]
-type LastInsertId = i64;
-#[cfg(not(feature = "sqlite"))]
-type LastInsertId = u64;
-
-#[derive(Debug, thiserror::Error)]
-pub enum Error {
-    #[error(transparent)]
-    Sql(#[from] sqlx::Error),
-    #[error(transparent)]
-    Migration(#[from] sqlx::migrate::MigrateError),
-    #[error("database {0} not loaded")]
-    DatabaseNotLoaded(String),
-    #[error("unsupported datatype: {0}")]
-    UnsupportedDatatype(String),
-}
-
-impl Serialize for Error {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(self.to_string().as_ref())
-    }
-}
-
-type Result<T> = std::result::Result<T, Error>;
-
-#[cfg(feature = "sqlite")]
-/// Resolves the App's **file path** from the `AppHandle` context
-/// object
-fn app_path<R: Runtime>(app: &AppHandle<R>) -> PathBuf {
-    app.path().app_config_dir().expect("No App path was found!")
-}
-
-#[cfg(feature = "sqlite")]
-/// Maps the user supplied DB connection string to a connection string
-/// with a fully qualified file path to the App's designed "app_path"
-fn path_mapper(mut app_path: PathBuf, connection_string: &str) -> String {
-    app_path.push(
-        connection_string
-            .split_once(':')
-            .expect("Couldn't parse the connection string for DB!")
-            .1,
-    );
-
-    format!(
-        "sqlite:{}",
-        app_path
-            .to_str()
-            .expect("Problem creating fully qualified path to Database file!")
-    )
-}
-
-#[derive(Default)]
-pub struct DbInstances(pub Mutex<HashMap<String, Pool<Db>>>);
-
-struct Migrations(Mutex<HashMap<String, MigrationList>>);
-
-#[derive(Default, Clone, Deserialize)]
-pub struct PluginConfig {
-    #[serde(default)]
-    preload: Vec<String>,
-}
-
-#[derive(Debug)]
-pub enum MigrationKind {
-    Up,
-    Down,
-}
-
-impl From<MigrationKind> for MigrationType {
-    fn from(kind: MigrationKind) -> Self {
-        match kind {
-            MigrationKind::Up => Self::ReversibleUp,
-            MigrationKind::Down => Self::ReversibleDown,
-        }
-    }
-}
-
-/// A migration definition.
-#[derive(Debug)]
-pub struct Migration {
-    pub version: i64,
-    pub description: &'static str,
-    pub sql: &'static str,
-    pub kind: MigrationKind,
-}
-
-#[derive(Debug)]
-struct MigrationList(Vec<Migration>);
-
-impl MigrationSource<'static> for MigrationList {
-    fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
-        Box::pin(async move {
-            let mut migrations = Vec::new();
-            for migration in self.0 {
-                if matches!(migration.kind, MigrationKind::Up) {
-                    migrations.push(SqlxMigration::new(
-                        migration.version,
-                        migration.description.into(),
-                        migration.kind.into(),
-                        migration.sql.into(),
-                        false,
-                    ));
-                }
-            }
-            Ok(migrations)
-        })
-    }
-}
-
-#[command]
-async fn load<R: Runtime>(
-    #[allow(unused_variables)] app: AppHandle<R>,
-    db_instances: State<'_, DbInstances>,
-    migrations: State<'_, Migrations>,
-    db: String,
-) -> Result<String> {
-    #[cfg(feature = "sqlite")]
-    let fqdb = path_mapper(app_path(&app), &db);
-    #[cfg(not(feature = "sqlite"))]
-    let fqdb = db.clone();
-
-    #[cfg(feature = "sqlite")]
-    create_dir_all(app_path(&app)).expect("Problem creating App directory!");
-
-    if !Db::database_exists(&fqdb).await.unwrap_or(false) {
-        Db::create_database(&fqdb).await?;
-    }
-    let pool = Pool::connect(&fqdb).await?;
-
-    if let Some(migrations) = migrations.0.lock().await.remove(&db) {
-        let migrator = Migrator::new(migrations).await?;
-        migrator.run(&pool).await?;
-    }
-
-    db_instances.0.lock().await.insert(db.clone(), pool);
-    Ok(db)
-}
-
-/// Allows the database connection(s) to be closed; if no database
-/// name is passed in then _all_ database connection pools will be
-/// shut down.
-#[command]
-async fn close(db_instances: State<'_, DbInstances>, db: Option<String>) -> Result<bool> {
-    let mut instances = db_instances.0.lock().await;
-
-    let pools = if let Some(db) = db {
-        vec![db]
-    } else {
-        instances.keys().cloned().collect()
-    };
-
-    for pool in pools {
-        let db = instances
-            .get_mut(&pool) //
-            .ok_or(Error::DatabaseNotLoaded(pool))?;
-        db.close().await;
-    }
-
-    Ok(true)
-}
-
-/// Execute a command against the database
-#[command]
-async fn execute(
-    db_instances: State<'_, DbInstances>,
-    db: String,
-    query: String,
-    values: Vec<JsonValue>,
-) -> Result<(u64, LastInsertId)> {
-    let mut instances = db_instances.0.lock().await;
-
-    let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
-    let mut query = sqlx::query(&query);
-    for value in values {
-        if value.is_null() {
-            query = query.bind(None::<JsonValue>);
-        } else if value.is_string() {
-            query = query.bind(value.as_str().unwrap().to_owned())
-        } else if let Some(number) = value.as_number() {
-            query = query.bind(number.as_f64().unwrap_or_default())
-        } else {
-            query = query.bind(value);
-        }
-    }
-    let result = query.execute(&*db).await?;
-    #[cfg(feature = "sqlite")]
-    let r = Ok((result.rows_affected(), result.last_insert_rowid()));
-    #[cfg(feature = "mysql")]
-    let r = Ok((result.rows_affected(), result.last_insert_id()));
-    #[cfg(feature = "postgres")]
-    let r = Ok((result.rows_affected(), 0));
-    r
-}
-
-#[command]
-async fn select(
-    db_instances: State<'_, DbInstances>,
-    db: String,
-    query: String,
-    values: Vec<JsonValue>,
-) -> Result<Vec<IndexMap<String, JsonValue>>> {
-    let mut instances = db_instances.0.lock().await;
-    let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
-    let mut query = sqlx::query(&query);
-    for value in values {
-        if value.is_null() {
-            query = query.bind(None::<JsonValue>);
-        } else if value.is_string() {
-            query = query.bind(value.as_str().unwrap().to_owned())
-        } else if let Some(number) = value.as_number() {
-            query = query.bind(number.as_f64().unwrap_or_default())
-        } else {
-            query = query.bind(value);
-        }
-    }
-    let rows = query.fetch_all(&*db).await?;
-    let mut values = Vec::new();
-    for row in rows {
-        let mut value = IndexMap::default();
-        for (i, column) in row.columns().iter().enumerate() {
-            let v = row.try_get_raw(i)?;
-
-            let v = crate::decode::to_json(v)?;
-
-            value.insert(column.name().to_string(), v);
-        }
-
-        values.push(value);
-    }
-
-    Ok(values)
-}
-
-/// Tauri SQL plugin builder.
-#[derive(Default)]
-pub struct Builder {
-    migrations: Option<HashMap<String, MigrationList>>,
-}
-
-impl Builder {
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    /// Add migrations to a database.
-    #[must_use]
-    pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
-        self.migrations
-            .get_or_insert(Default::default())
-            .insert(db_url.to_string(), MigrationList(migrations));
-        self
-    }
-
-    pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
-        PluginBuilder::<R, Option<PluginConfig>>::new("sql")
-            .invoke_handler(tauri::generate_handler![load, execute, select, close])
-            .setup(|app, api| {
-                let config = api.config().clone().unwrap_or_default();
-
-                #[cfg(feature = "sqlite")]
-                create_dir_all(app_path(app)).expect("problems creating App directory!");
-
-                tauri::async_runtime::block_on(async move {
-                    let instances = DbInstances::default();
-                    let mut lock = instances.0.lock().await;
-                    for db in config.preload {
-                        #[cfg(feature = "sqlite")]
-                        let fqdb = path_mapper(app_path(app), &db);
-                        #[cfg(not(feature = "sqlite"))]
-                        let fqdb = db.clone();
-
-                        if !Db::database_exists(&fqdb).await.unwrap_or(false) {
-                            Db::create_database(&fqdb).await?;
-                        }
-                        let pool = Pool::connect(&fqdb).await?;
-
-                        if let Some(migrations) = self.migrations.as_mut().unwrap().remove(&db) {
-                            let migrator = Migrator::new(migrations).await?;
-                            migrator.run(&pool).await?;
-                        }
-                        lock.insert(db, pool);
-                    }
-                    drop(lock);
-
-                    app.manage(instances);
-                    app.manage(Migrations(Mutex::new(
-                        self.migrations.take().unwrap_or_default(),
-                    )));
-
-                    Ok(())
-                })
-            })
-            .on_event(|app, event| {
-                if let RunEvent::Exit = event {
-                    tauri::async_runtime::block_on(async move {
-                        let instances = &*app.state::<DbInstances>();
-                        let instances = instances.0.lock().await;
-                        for value in instances.values() {
-                            value.close().await;
-                        }
-                    });
-                }
-            })
-            .build()
-    }
-}

diff --git a/plugins/sql/src/wrapper.rs b/plugins/sql/src/wrapper.rs
new file mode 100644
index 00000000..90631dac
--- /dev/null
+++ b/plugins/sql/src/wrapper.rs
@@ -0,0 +1,328 @@
+// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+#[cfg(feature = "sqlite")]
+use std::fs::create_dir_all;
+
+use indexmap::IndexMap;
+use serde_json::Value as JsonValue;
+#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
+use sqlx::{migrate::MigrateDatabase, Column, Executor, Pool, Row};
+#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
+use tauri::Manager;
+use tauri::{AppHandle, Runtime};
+
+#[cfg(feature = "mysql")]
+use sqlx::MySql;
+#[cfg(feature = "postgres")]
+use sqlx::Postgres;
+#[cfg(feature = "sqlite")]
+use sqlx::Sqlite;
+
+use crate::LastInsertId;
+
+pub enum DbPool {
+    #[cfg(feature = "sqlite")]
+    Sqlite(Pool<Sqlite>),
+    #[cfg(feature = "mysql")]
+    MySql(Pool<MySql>),
+    #[cfg(feature = "postgres")]
+    Postgres(Pool<Postgres>),
+    #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
+    None,
+}
+
+// public methods
+/* impl DbPool {
+    /// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools.
+    #[cfg(feature = "sqlite")]
+    pub fn sqlite(&self) -> Option<&Pool<Sqlite>> {
+        match self {
+            DbPool::Sqlite(pool) => Some(pool),
+            _ => None,
+        }
+    }
+
+    /// Get the inner MySql Pool. Returns None for Sqlite and Postgres pools.
+    #[cfg(feature = "mysql")]
+    pub fn mysql(&self) -> Option<&Pool<MySql>> {
+        match self {
+            DbPool::MySql(pool) => Some(pool),
+            _ => None,
+        }
+    }
+
+    /// Get the inner Postgres Pool. Returns None for MySql and Sqlite pools.
+    #[cfg(feature = "postgres")]
+    pub fn postgres(&self) -> Option<&Pool<Postgres>> {
+        match self {
+            DbPool::Postgres(pool) => Some(pool),
+            _ => None,
+        }
+    }
+} */
+
+ #[cfg(feature = "postgres")] + pub fn postgres(&self) -> Option<&Pool> { + match self { + DbPool::Postgres(pool) => Some(pool), + _ => None, + } + } +} */ + +// private methods +impl DbPool { + pub(crate) async fn connect( + conn_url: &str, + _app: &AppHandle, + ) -> Result { + match conn_url + .split_once(':') + .ok_or_else(|| crate::Error::InvalidDbUrl(conn_url.to_string()))? + .0 + { + #[cfg(feature = "sqlite")] + "sqlite" => { + let app_path = _app + .path() + .app_config_dir() + .expect("No App config path was found!"); + + create_dir_all(&app_path).expect("Couldn't create app config dir"); + + let conn_url = &path_mapper(app_path, conn_url); + + if !Sqlite::database_exists(conn_url).await.unwrap_or(false) { + Sqlite::create_database(conn_url).await?; + } + Ok(Self::Sqlite(Pool::connect(conn_url).await?)) + } + #[cfg(feature = "mysql")] + "mysql" => { + if !MySql::database_exists(conn_url).await.unwrap_or(false) { + MySql::create_database(conn_url).await?; + } + Ok(Self::MySql(Pool::connect(conn_url).await?)) + } + #[cfg(feature = "postgres")] + "postgres" => { + if !Postgres::database_exists(conn_url).await.unwrap_or(false) { + Postgres::create_database(conn_url).await?; + } + Ok(Self::Postgres(Pool::connect(conn_url).await?)) + } + _ => Err(crate::Error::InvalidDbUrl(conn_url.to_string())), + } + } + + pub(crate) async fn migrate( + &self, + _migrator: &sqlx::migrate::Migrator, + ) -> Result<(), crate::Error> { + match self { + #[cfg(feature = "sqlite")] + DbPool::Sqlite(pool) => _migrator.run(pool).await?, + #[cfg(feature = "mysql")] + DbPool::MySql(pool) => _migrator.run(pool).await?, + #[cfg(feature = "postgres")] + DbPool::Postgres(pool) => _migrator.run(pool).await?, + #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))] + DbPool::None => (), + } + Ok(()) + } + + pub(crate) async fn close(&self) { + match self { + #[cfg(feature = "sqlite")] + DbPool::Sqlite(pool) => pool.close().await, + #[cfg(feature = "mysql")] + DbPool::MySql(pool) => pool.close().await, + #[cfg(feature = "postgres")] + DbPool::Postgres(pool) => pool.close().await, + #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))] + DbPool::None => (), + } + } + + pub(crate) async fn execute( + &self, + _query: String, + _values: Vec, + ) -> Result<(u64, LastInsertId), crate::Error> { + Ok(match self { + #[cfg(feature = "sqlite")] + DbPool::Sqlite(pool) => { + let mut query = sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let result = pool.execute(query).await?; + ( + result.rows_affected(), + LastInsertId::Sqlite(result.last_insert_rowid()), + ) + } + #[cfg(feature = "mysql")] + DbPool::MySql(pool) => { + let mut query = sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let result = pool.execute(query).await?; + ( + result.rows_affected(), + LastInsertId::MySql(result.last_insert_id()), + ) + } + #[cfg(feature = "postgres")] + DbPool::Postgres(pool) => { + let mut query = 
sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let result = pool.execute(query).await?; + (result.rows_affected(), LastInsertId::Postgres(())) + } + #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))] + DbPool::None => (0, LastInsertId::None), + }) + } + + pub(crate) async fn select( + &self, + _query: String, + _values: Vec, + ) -> Result>, crate::Error> { + Ok(match self { + #[cfg(feature = "sqlite")] + DbPool::Sqlite(pool) => { + let mut query = sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let rows = pool.fetch_all(query).await?; + let mut values = Vec::new(); + for row in rows { + let mut value = IndexMap::default(); + for (i, column) in row.columns().iter().enumerate() { + let v = row.try_get_raw(i)?; + + let v = crate::decode::sqlite::to_json(v)?; + + value.insert(column.name().to_string(), v); + } + + values.push(value); + } + values + } + #[cfg(feature = "mysql")] + DbPool::MySql(pool) => { + let mut query = sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let rows = pool.fetch_all(query).await?; + let mut values = Vec::new(); + for row in rows { + let mut value = IndexMap::default(); + for (i, column) in row.columns().iter().enumerate() { + let v = row.try_get_raw(i)?; + + let v = crate::decode::mysql::to_json(v)?; + + value.insert(column.name().to_string(), v); + } + + values.push(value); + } + values + } + #[cfg(feature = "postgres")] + DbPool::Postgres(pool) => { + let mut query = sqlx::query(&_query); + for value in _values { + if value.is_null() { + query = query.bind(None::); + } else if value.is_string() { + query = query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + query = query.bind(number.as_f64().unwrap_or_default()) + } else { + query = query.bind(value); + } + } + let rows = pool.fetch_all(query).await?; + let mut values = Vec::new(); + for row in rows { + let mut value = IndexMap::default(); + for (i, column) in row.columns().iter().enumerate() { + let v = row.try_get_raw(i)?; + + let v = crate::decode::postgres::to_json(v)?; + + value.insert(column.name().to_string(), v); + } + + values.push(value); + } + values + } + #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))] + DbPool::None => Vec::new(), + }) + } +} + +#[cfg(feature = "sqlite")] +/// Maps the user supplied DB connection string to a connection string +/// with a fully qualified file path to the App's designed "app_path" +fn path_mapper(mut app_path: std::path::PathBuf, connection_string: &str) -> String { + app_path.push( + connection_string + .split_once(':') + .expect("Couldn't parse the connection string for DB!") + 
.1, + ); + + format!( + "sqlite:{}", + app_path + .to_str() + .expect("Problem creating fully qualified path to Database file!") + ) +}
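`DbPool::connect` dispatches on the scheme before the first `:`, which is what lets one app hold `sqlite:`, `mysql:`, and `postgres:` pools side by side; only SQLite URLs are rewritten by `path_mapper` to live under the app's config directory. A sketch of that rewrite with a hypothetical directory (the real one comes from `app.path().app_config_dir()`):

```rust
use std::path::PathBuf;

// Re-implementation of path_mapper's rewrite, for illustration only.
fn path_mapper(mut app_path: PathBuf, connection_string: &str) -> String {
    // Keep everything after the "sqlite:" scheme and anchor it to app_path.
    app_path.push(connection_string.split_once(':').expect("no scheme").1);
    format!("sqlite:{}", app_path.display())
}

fn main() {
    // Hypothetical app config dir for a Linux install.
    let app_path = PathBuf::from("/home/alice/.config/com.example.app");
    assert_eq!(
        path_mapper(app_path, "sqlite:todos.db"),
        "sqlite:/home/alice/.config/com.example.app/todos.db"
    );
}
```

MySQL and Postgres connection strings are handed to sqlx unchanged, so the usual `mysql://user:pass@host/db` and `postgres://user:pass@host/db` forms apply.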