refactor(sql): Allow multiple drivers at the same time (#1838)

* refactor(sql): Allow multiple drivers at the same time

* fmt

* remove default feature comment [skip ci]

* what was that doing there [skip ci]

* disable public methods for now
Fabian-Lars 8 months ago committed by GitHub
parent 68579934c9
commit 30bcf5dcc2

@@ -0,0 +1,5 @@
---
sql: patch
---
It is now possible to enable multiple SQL backends at the same time. In addition, enabling no backend at all no longer produces a compile error!
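
What this unlocks, as a minimal sketch: with both the `sqlite` and `postgres` cargo features enabled on the plugin, an app registers it once and can load either backend at runtime (the URLs in the comments are illustrative, not from this diff):

    use tauri_plugin_sql::Builder;

    fn main() {
        tauri::Builder::default()
            // With `features = ["sqlite", "postgres"]` on tauri-plugin-sql,
            // the frontend can now `load("sqlite:app.db")` and
            // `load("postgres://localhost/app")` side by side.
            .plugin(Builder::new().build())
            .run(tauri::generate_context!())
            .expect("error while running tauri application");
    }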

@@ -148,13 +148,7 @@ jobs:
       - uses: Swatinem/rust-cache@v2
-     - name: clippy ${{ matrix.package }}
-       if: matrix.package != 'tauri-plugin-sql'
-       run: cargo clippy --package ${{ matrix.package }} --all-targets -- -D warnings
-     - name: clippy ${{ matrix.package }} mysql
-       if: matrix.package == 'tauri-plugin-sql'
-       run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features mysql -- -D warnings
-     - name: clippy ${{ matrix.package }} postgres
-       if: matrix.package == 'tauri-plugin-sql'
-       run: cargo clippy --package ${{ matrix.package }} --all-targets --no-default-features --features postgres -- -D warnings
+     - name: clippy ${{ matrix.package }} --all-features
+       run: cargo clippy --package ${{ matrix.package }} --all-targets --all-features -- -D warnings

@@ -215,21 +215,9 @@ jobs:
        run: cargo +stable install cross --git https://github.com/cross-rs/cross
      - name: test ${{ matrix.package }}
-       if: matrix.package != 'tauri-plugin-sql' && matrix.package != 'tauri-plugin-http'
+       if: matrix.package != 'tauri-plugin-http'
        run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --all-features
      - name: test ${{ matrix.package }}
        if: matrix.package == 'tauri-plugin-http'
        run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets
-     - name: test ${{ matrix.package }} sqlite
-       if: matrix.package == 'tauri-plugin-sql'
-       run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features sqlite
-     - name: test ${{ matrix.package }} mysql
-       if: matrix.package == 'tauri-plugin-sql'
-       run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features mysql
-     - name: test ${{ matrix.package }} postgres
-       if: matrix.package == 'tauri-plugin-sql'
-       run: ${{ matrix.platform.runner }} ${{ matrix.platform.command }} --package ${{ matrix.package }} --target ${{ matrix.platform.target }} --all-targets --features postgres

@@ -74,7 +74,7 @@ impl Builder {
         let asset_resolver = app.asset_resolver();
         std::thread::spawn(move || {
             let server =
-                Server::http(&format!("localhost:{port}")).expect("Unable to spawn server");
+                Server::http(format!("localhost:{port}")).expect("Unable to spawn server");
             for req in server.incoming_requests() {
                 let path = req
                     .url()
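
The `&format!` change above is a clippy fix: the constructor takes its address argument generically, so borrowing a freshly built temporary `String` is a needless borrow. A standalone sketch of the pattern being flagged (`takes_addr` is a hypothetical stand-in, not the plugin's API):

    use std::net::ToSocketAddrs;

    // Hypothetical stand-in for an API that, like `Server::http`, accepts
    // anything address-like.
    fn takes_addr(addr: impl ToSocketAddrs) {
        let _ = addr.to_socket_addrs();
    }

    fn main() {
        let port = 8080;
        // clippy: `&format!(…)` borrows a temporary String for no benefit;
        // passing the String by value is preferred.
        takes_addr(format!("localhost:{port}"));
    }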

@@ -40,3 +40,4 @@ indexmap = { version = "2", features = ["serde"] }
 sqlite = ["sqlx/sqlite", "sqlx/runtime-tokio"]
 mysql = ["sqlx/mysql", "sqlx/runtime-tokio-rustls"]
 postgres = ["sqlx/postgres", "sqlx/runtime-tokio-rustls"]
+# TODO: bundled-cipher etc

@@ -0,0 +1,82 @@
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
use indexmap::IndexMap;
use serde_json::Value as JsonValue;
use sqlx::migrate::Migrator;
use tauri::{command, AppHandle, Runtime, State};
use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations};
#[command]
pub(crate) async fn load<R: Runtime>(
app: AppHandle<R>,
db_instances: State<'_, DbInstances>,
migrations: State<'_, Migrations>,
db: String,
) -> Result<String, crate::Error> {
let pool = DbPool::connect(&db, &app).await?;
if let Some(migrations) = migrations.0.lock().await.remove(&db) {
let migrator = Migrator::new(migrations).await?;
pool.migrate(&migrator).await?;
}
db_instances.0.lock().await.insert(db.clone(), pool);
Ok(db)
}
/// Allows the database connection(s) to be closed; if no database
/// name is passed in then _all_ database connection pools will be
/// shut down.
#[command]
pub(crate) async fn close(
db_instances: State<'_, DbInstances>,
db: Option<String>,
) -> Result<bool, crate::Error> {
let mut instances = db_instances.0.lock().await;
let pools = if let Some(db) = db {
vec![db]
} else {
instances.keys().cloned().collect()
};
for pool in pools {
let db = instances
.get_mut(&pool)
.ok_or(Error::DatabaseNotLoaded(pool))?;
db.close().await;
}
Ok(true)
}
/// Execute a command against the database
#[command]
pub(crate) async fn execute(
db_instances: State<'_, DbInstances>,
db: String,
query: String,
values: Vec<JsonValue>,
) -> Result<(u64, LastInsertId), crate::Error> {
let mut instances = db_instances.0.lock().await;
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
db.execute(query, values).await
}
#[command]
pub(crate) async fn select(
db_instances: State<'_, DbInstances>,
db: String,
query: String,
values: Vec<JsonValue>,
) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
let mut instances = db_instances.0.lock().await;
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
db.select(query, values).await
}
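
One detail worth noting in these commands: pools are keyed by the exact string handed to `load`, so `execute`, `select`, and `close` must pass the same string back. A standalone sketch of that contract, with a unit stand-in for the pool type:

    use std::collections::HashMap;
    use tokio::sync::Mutex;

    // Stand-in for DbInstances: pools live under the original connection string.
    struct Instances(Mutex<HashMap<String, ()>>);

    #[tokio::main]
    async fn main() {
        let instances = Instances(Mutex::new(HashMap::new()));
        instances.0.lock().await.insert("sqlite:test.db".into(), ());
        // A different spelling of the same URL misses the map:
        assert!(instances.0.lock().await.get("sqlite:./test.db").is_none());
        assert!(instances.0.lock().await.get("sqlite:test.db").is_some());
    }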

@@ -3,17 +3,8 @@
 // SPDX-License-Identifier: MIT
 #[cfg(feature = "mysql")]
-mod mysql;
+pub(crate) mod mysql;
 #[cfg(feature = "postgres")]
-mod postgres;
+pub(crate) mod postgres;
 #[cfg(feature = "sqlite")]
-mod sqlite;
-#[cfg(feature = "mysql")]
-pub(crate) use mysql::to_json;
-#[cfg(feature = "postgres")]
-pub(crate) use postgres::to_json;
-#[cfg(feature = "sqlite")]
-pub(crate) use sqlite::to_json;
+pub(crate) mod sqlite;
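
The `to_json` re-exports could not survive this refactor: with two or more driver features enabled, the old `pub(crate) use` lines would each pull a `to_json` into the same scope and collide, so call sites now use module-qualified paths. A minimal standalone sketch of the clash:

    mod mysql {
        pub fn to_json() {}
    }
    mod postgres {
        pub fn to_json() {}
    }

    // With both features on, the old re-exports would amount to:
    // use mysql::to_json;
    // use postgres::to_json; // error[E0252]: `to_json` defined multiple times

    fn main() {
        // Path-qualified calls, like the new decode::sqlite::to_json etc., avoid it.
        mysql::to_json();
        postgres::to_json();
    }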

@@ -0,0 +1,28 @@
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
use serde::{Serialize, Serializer};
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Sql(#[from] sqlx::Error),
#[error(transparent)]
Migration(#[from] sqlx::migrate::MigrateError),
#[error("invalid connection url: {0}")]
InvalidDbUrl(String),
#[error("database {0} not loaded")]
DatabaseNotLoaded(String),
#[error("unsupported datatype: {0}")]
UnsupportedDatatype(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
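
These errors reach the frontend as plain strings through the custom `Serialize` impl. A standalone re-creation showing the observable behavior (assumes `thiserror`, `serde`, and `serde_json` as dependencies):

    use serde::{Serialize, Serializer};

    #[derive(Debug, thiserror::Error)]
    enum Error {
        #[error("invalid connection url: {0}")]
        InvalidDbUrl(String),
    }

    impl Serialize for Error {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            serializer.serialize_str(&self.to_string())
        }
    }

    fn main() {
        let err = Error::InvalidDbUrl("mydb".into());
        // The frontend promise rejects with this bare string:
        assert_eq!(
            serde_json::to_string(&err).unwrap(),
            "\"invalid connection url: mydb\""
        );
    }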

@@ -11,20 +11,168 @@
   html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
 )]
-#[cfg(any(
-    all(feature = "sqlite", feature = "mysql"),
-    all(feature = "sqlite", feature = "postgres"),
-    all(feature = "mysql", feature = "postgres")
-))]
-compile_error!(
-    "Only one database driver can be enabled. Set the feature flag for the driver of your choice."
-);
-#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
-compile_error!(
-    "Database driver not defined. Please set the feature flag for the driver of your choice."
-);
mod commands;
mod decode;
mod plugin;
pub use plugin::*;
mod error;
mod wrapper;
pub use error::Error;
pub use wrapper::DbPool;
use futures_core::future::BoxFuture;
use serde::{Deserialize, Serialize};
use sqlx::{
error::BoxDynError,
migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator},
};
use tauri::{
plugin::{Builder as PluginBuilder, TauriPlugin},
Manager, RunEvent, Runtime,
};
use tokio::sync::Mutex;
use std::collections::HashMap;
#[derive(Default)]
pub struct DbInstances(pub Mutex<HashMap<String, DbPool>>);
#[derive(Serialize)]
#[serde(untagged)]
pub(crate) enum LastInsertId {
#[cfg(feature = "sqlite")]
Sqlite(i64),
#[cfg(feature = "mysql")]
MySql(u64),
#[cfg(feature = "postgres")]
Postgres(()),
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
None,
}
struct Migrations(Mutex<HashMap<String, MigrationList>>);
#[derive(Default, Clone, Deserialize)]
pub struct PluginConfig {
#[serde(default)]
preload: Vec<String>,
}
#[derive(Debug)]
pub enum MigrationKind {
Up,
Down,
}
impl From<MigrationKind> for MigrationType {
fn from(kind: MigrationKind) -> Self {
match kind {
MigrationKind::Up => Self::ReversibleUp,
MigrationKind::Down => Self::ReversibleDown,
}
}
}
/// A migration definition.
#[derive(Debug)]
pub struct Migration {
pub version: i64,
pub description: &'static str,
pub sql: &'static str,
pub kind: MigrationKind,
}
#[derive(Debug)]
struct MigrationList(Vec<Migration>);
impl MigrationSource<'static> for MigrationList {
fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
Box::pin(async move {
let mut migrations = Vec::new();
for migration in self.0 {
if matches!(migration.kind, MigrationKind::Up) {
migrations.push(SqlxMigration::new(
migration.version,
migration.description.into(),
migration.kind.into(),
migration.sql.into(),
false,
));
}
}
Ok(migrations)
})
}
}
/// Tauri SQL plugin builder.
#[derive(Default)]
pub struct Builder {
migrations: Option<HashMap<String, MigrationList>>,
}
impl Builder {
pub fn new() -> Self {
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
eprintln!("No sql driver enabled. Please set at least one of the \"sqlite\", \"mysql\", \"postgres\" feature flags.");
Self::default()
}
/// Add migrations to a database.
#[must_use]
pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
self.migrations
.get_or_insert(Default::default())
.insert(db_url.to_string(), MigrationList(migrations));
self
}
pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
PluginBuilder::<R, Option<PluginConfig>>::new("sql")
.invoke_handler(tauri::generate_handler![
commands::load,
commands::execute,
commands::select,
commands::close
])
.setup(|app, api| {
let config = api.config().clone().unwrap_or_default();
tauri::async_runtime::block_on(async move {
let instances = DbInstances::default();
let mut lock = instances.0.lock().await;
for db in config.preload {
let pool = DbPool::connect(&db, app).await?;
if let Some(migrations) = self.migrations.as_mut().unwrap().remove(&db) {
let migrator = Migrator::new(migrations).await?;
pool.migrate(&migrator).await?;
}
lock.insert(db, pool);
}
drop(lock);
app.manage(instances);
app.manage(Migrations(Mutex::new(
self.migrations.take().unwrap_or_default(),
)));
Ok(())
})
})
.on_event(|app, event| {
if let RunEvent::Exit = event {
tauri::async_runtime::block_on(async move {
let instances = &*app.state::<DbInstances>();
let instances = instances.0.lock().await;
for value in instances.values() {
value.close().await;
}
});
}
})
.build()
}
}
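
Putting the new `Builder` to use, a sketch of what the multi-driver setup now allows: migrations registered for a SQLite and a Postgres database in the same app (URLs, table names, and SQL are illustrative):

    use tauri_plugin_sql::{Builder, Migration, MigrationKind};

    fn main() {
        tauri::Builder::default()
            .plugin(
                Builder::new()
                    .add_migrations(
                        "sqlite:app.db",
                        vec![Migration {
                            version: 1,
                            description: "create_users",
                            sql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);",
                            kind: MigrationKind::Up,
                        }],
                    )
                    .add_migrations(
                        "postgres://localhost/app",
                        vec![Migration {
                            version: 1,
                            description: "create_users",
                            sql: "CREATE TABLE users (id SERIAL PRIMARY KEY, name TEXT);",
                            kind: MigrationKind::Up,
                        }],
                    )
                    .build(),
            )
            .run(tauri::generate_context!())
            .expect("error while running tauri application");
    }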

@@ -1,343 +0,0 @@
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
use futures_core::future::BoxFuture;
use serde::{ser::Serializer, Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sqlx::{
error::BoxDynError,
migrate::{
MigrateDatabase, Migration as SqlxMigration, MigrationSource, MigrationType, Migrator,
},
Column, Pool, Row,
};
use tauri::{
command,
plugin::{Builder as PluginBuilder, TauriPlugin},
AppHandle, Manager, RunEvent, Runtime, State,
};
use tokio::sync::Mutex;
use indexmap::IndexMap;
use std::collections::HashMap;
#[cfg(feature = "sqlite")]
use std::{fs::create_dir_all, path::PathBuf};
#[cfg(feature = "sqlite")]
type Db = sqlx::sqlite::Sqlite;
#[cfg(feature = "mysql")]
type Db = sqlx::mysql::MySql;
#[cfg(feature = "postgres")]
type Db = sqlx::postgres::Postgres;
#[cfg(feature = "sqlite")]
type LastInsertId = i64;
#[cfg(not(feature = "sqlite"))]
type LastInsertId = u64;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Sql(#[from] sqlx::Error),
#[error(transparent)]
Migration(#[from] sqlx::migrate::MigrateError),
#[error("database {0} not loaded")]
DatabaseNotLoaded(String),
#[error("unsupported datatype: {0}")]
UnsupportedDatatype(String),
}
impl Serialize for Error {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.to_string().as_ref())
}
}
type Result<T> = std::result::Result<T, Error>;
#[cfg(feature = "sqlite")]
/// Resolves the App's **file path** from the `AppHandle` context
/// object
fn app_path<R: Runtime>(app: &AppHandle<R>) -> PathBuf {
app.path().app_config_dir().expect("No App path was found!")
}
#[cfg(feature = "sqlite")]
/// Maps the user-supplied DB connection string to a connection string
/// with a fully qualified file path to the app's designated "app_path"
fn path_mapper(mut app_path: PathBuf, connection_string: &str) -> String {
app_path.push(
connection_string
.split_once(':')
.expect("Couldn't parse the connection string for DB!")
.1,
);
format!(
"sqlite:{}",
app_path
.to_str()
.expect("Problem creating fully qualified path to Database file!")
)
}
#[derive(Default)]
pub struct DbInstances(pub Mutex<HashMap<String, Pool<Db>>>);
struct Migrations(Mutex<HashMap<String, MigrationList>>);
#[derive(Default, Clone, Deserialize)]
pub struct PluginConfig {
#[serde(default)]
preload: Vec<String>,
}
#[derive(Debug)]
pub enum MigrationKind {
Up,
Down,
}
impl From<MigrationKind> for MigrationType {
fn from(kind: MigrationKind) -> Self {
match kind {
MigrationKind::Up => Self::ReversibleUp,
MigrationKind::Down => Self::ReversibleDown,
}
}
}
/// A migration definition.
#[derive(Debug)]
pub struct Migration {
pub version: i64,
pub description: &'static str,
pub sql: &'static str,
pub kind: MigrationKind,
}
#[derive(Debug)]
struct MigrationList(Vec<Migration>);
impl MigrationSource<'static> for MigrationList {
fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
Box::pin(async move {
let mut migrations = Vec::new();
for migration in self.0 {
if matches!(migration.kind, MigrationKind::Up) {
migrations.push(SqlxMigration::new(
migration.version,
migration.description.into(),
migration.kind.into(),
migration.sql.into(),
false,
));
}
}
Ok(migrations)
})
}
}
#[command]
async fn load<R: Runtime>(
#[allow(unused_variables)] app: AppHandle<R>,
db_instances: State<'_, DbInstances>,
migrations: State<'_, Migrations>,
db: String,
) -> Result<String> {
#[cfg(feature = "sqlite")]
let fqdb = path_mapper(app_path(&app), &db);
#[cfg(not(feature = "sqlite"))]
let fqdb = db.clone();
#[cfg(feature = "sqlite")]
create_dir_all(app_path(&app)).expect("Problem creating App directory!");
if !Db::database_exists(&fqdb).await.unwrap_or(false) {
Db::create_database(&fqdb).await?;
}
let pool = Pool::connect(&fqdb).await?;
if let Some(migrations) = migrations.0.lock().await.remove(&db) {
let migrator = Migrator::new(migrations).await?;
migrator.run(&pool).await?;
}
db_instances.0.lock().await.insert(db.clone(), pool);
Ok(db)
}
/// Allows the database connection(s) to be closed; if no database
/// name is passed in then _all_ database connection pools will be
/// shut down.
#[command]
async fn close(db_instances: State<'_, DbInstances>, db: Option<String>) -> Result<bool> {
let mut instances = db_instances.0.lock().await;
let pools = if let Some(db) = db {
vec![db]
} else {
instances.keys().cloned().collect()
};
for pool in pools {
let db = instances
.get_mut(&pool) //
.ok_or(Error::DatabaseNotLoaded(pool))?;
db.close().await;
}
Ok(true)
}
/// Execute a command against the database
#[command]
async fn execute(
db_instances: State<'_, DbInstances>,
db: String,
query: String,
values: Vec<JsonValue>,
) -> Result<(u64, LastInsertId)> {
let mut instances = db_instances.0.lock().await;
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
let mut query = sqlx::query(&query);
for value in values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let result = query.execute(&*db).await?;
#[cfg(feature = "sqlite")]
let r = Ok((result.rows_affected(), result.last_insert_rowid()));
#[cfg(feature = "mysql")]
let r = Ok((result.rows_affected(), result.last_insert_id()));
#[cfg(feature = "postgres")]
let r = Ok((result.rows_affected(), 0));
r
}
#[command]
async fn select(
db_instances: State<'_, DbInstances>,
db: String,
query: String,
values: Vec<JsonValue>,
) -> Result<Vec<IndexMap<String, JsonValue>>> {
let mut instances = db_instances.0.lock().await;
let db = instances.get_mut(&db).ok_or(Error::DatabaseNotLoaded(db))?;
let mut query = sqlx::query(&query);
for value in values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let rows = query.fetch_all(&*db).await?;
let mut values = Vec::new();
for row in rows {
let mut value = IndexMap::default();
for (i, column) in row.columns().iter().enumerate() {
let v = row.try_get_raw(i)?;
let v = crate::decode::to_json(v)?;
value.insert(column.name().to_string(), v);
}
values.push(value);
}
Ok(values)
}
/// Tauri SQL plugin builder.
#[derive(Default)]
pub struct Builder {
migrations: Option<HashMap<String, MigrationList>>,
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
/// Add migrations to a database.
#[must_use]
pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
self.migrations
.get_or_insert(Default::default())
.insert(db_url.to_string(), MigrationList(migrations));
self
}
pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
PluginBuilder::<R, Option<PluginConfig>>::new("sql")
.invoke_handler(tauri::generate_handler![load, execute, select, close])
.setup(|app, api| {
let config = api.config().clone().unwrap_or_default();
#[cfg(feature = "sqlite")]
create_dir_all(app_path(app)).expect("problems creating App directory!");
tauri::async_runtime::block_on(async move {
let instances = DbInstances::default();
let mut lock = instances.0.lock().await;
for db in config.preload {
#[cfg(feature = "sqlite")]
let fqdb = path_mapper(app_path(app), &db);
#[cfg(not(feature = "sqlite"))]
let fqdb = db.clone();
if !Db::database_exists(&fqdb).await.unwrap_or(false) {
Db::create_database(&fqdb).await?;
}
let pool = Pool::connect(&fqdb).await?;
if let Some(migrations) = self.migrations.as_mut().unwrap().remove(&db) {
let migrator = Migrator::new(migrations).await?;
migrator.run(&pool).await?;
}
lock.insert(db, pool);
}
drop(lock);
app.manage(instances);
app.manage(Migrations(Mutex::new(
self.migrations.take().unwrap_or_default(),
)));
Ok(())
})
})
.on_event(|app, event| {
if let RunEvent::Exit = event {
tauri::async_runtime::block_on(async move {
let instances = &*app.state::<DbInstances>();
let instances = instances.0.lock().await;
for value in instances.values() {
value.close().await;
}
});
}
})
.build()
}
}

@@ -0,0 +1,328 @@
// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: MIT
#[cfg(feature = "sqlite")]
use std::fs::create_dir_all;
use indexmap::IndexMap;
use serde_json::Value as JsonValue;
#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
use sqlx::{migrate::MigrateDatabase, Column, Executor, Pool, Row};
#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
use tauri::Manager;
use tauri::{AppHandle, Runtime};
#[cfg(feature = "mysql")]
use sqlx::MySql;
#[cfg(feature = "postgres")]
use sqlx::Postgres;
#[cfg(feature = "sqlite")]
use sqlx::Sqlite;
use crate::LastInsertId;
pub enum DbPool {
#[cfg(feature = "sqlite")]
Sqlite(Pool<Sqlite>),
#[cfg(feature = "mysql")]
MySql(Pool<MySql>),
#[cfg(feature = "postgres")]
Postgres(Pool<Postgres>),
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
None,
}
// public methods
/* impl DbPool {
/// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools.
#[cfg(feature = "sqlite")]
pub fn sqlite(&self) -> Option<&Pool<Sqlite>> {
match self {
DbPool::Sqlite(pool) => Some(pool),
_ => None,
}
}
/// Get the inner MySql Pool. Returns None for Sqlite and Postgres pools.
#[cfg(feature = "mysql")]
pub fn mysql(&self) -> Option<&Pool<MySql>> {
match self {
DbPool::MySql(pool) => Some(pool),
_ => None,
}
}
/// Get the inner Postgres Pool. Returns None for MySql and Sqlite pools.
#[cfg(feature = "postgres")]
pub fn postgres(&self) -> Option<&Pool<Postgres>> {
match self {
DbPool::Postgres(pool) => Some(pool),
_ => None,
}
}
} */
// private methods
impl DbPool {
pub(crate) async fn connect<R: Runtime>(
conn_url: &str,
_app: &AppHandle<R>,
) -> Result<Self, crate::Error> {
match conn_url
.split_once(':')
.ok_or_else(|| crate::Error::InvalidDbUrl(conn_url.to_string()))?
.0
{
#[cfg(feature = "sqlite")]
"sqlite" => {
let app_path = _app
.path()
.app_config_dir()
.expect("No App config path was found!");
create_dir_all(&app_path).expect("Couldn't create app config dir");
let conn_url = &path_mapper(app_path, conn_url);
if !Sqlite::database_exists(conn_url).await.unwrap_or(false) {
Sqlite::create_database(conn_url).await?;
}
Ok(Self::Sqlite(Pool::connect(conn_url).await?))
}
#[cfg(feature = "mysql")]
"mysql" => {
if !MySql::database_exists(conn_url).await.unwrap_or(false) {
MySql::create_database(conn_url).await?;
}
Ok(Self::MySql(Pool::connect(conn_url).await?))
}
#[cfg(feature = "postgres")]
"postgres" => {
if !Postgres::database_exists(conn_url).await.unwrap_or(false) {
Postgres::create_database(conn_url).await?;
}
Ok(Self::Postgres(Pool::connect(conn_url).await?))
}
_ => Err(crate::Error::InvalidDbUrl(conn_url.to_string())),
}
}
pub(crate) async fn migrate(
&self,
_migrator: &sqlx::migrate::Migrator,
) -> Result<(), crate::Error> {
match self {
#[cfg(feature = "sqlite")]
DbPool::Sqlite(pool) => _migrator.run(pool).await?,
#[cfg(feature = "mysql")]
DbPool::MySql(pool) => _migrator.run(pool).await?,
#[cfg(feature = "postgres")]
DbPool::Postgres(pool) => _migrator.run(pool).await?,
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
DbPool::None => (),
}
Ok(())
}
pub(crate) async fn close(&self) {
match self {
#[cfg(feature = "sqlite")]
DbPool::Sqlite(pool) => pool.close().await,
#[cfg(feature = "mysql")]
DbPool::MySql(pool) => pool.close().await,
#[cfg(feature = "postgres")]
DbPool::Postgres(pool) => pool.close().await,
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
DbPool::None => (),
}
}
pub(crate) async fn execute(
&self,
_query: String,
_values: Vec<JsonValue>,
) -> Result<(u64, LastInsertId), crate::Error> {
Ok(match self {
#[cfg(feature = "sqlite")]
DbPool::Sqlite(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let result = pool.execute(query).await?;
(
result.rows_affected(),
LastInsertId::Sqlite(result.last_insert_rowid()),
)
}
#[cfg(feature = "mysql")]
DbPool::MySql(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let result = pool.execute(query).await?;
(
result.rows_affected(),
LastInsertId::MySql(result.last_insert_id()),
)
}
#[cfg(feature = "postgres")]
DbPool::Postgres(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let result = pool.execute(query).await?;
(result.rows_affected(), LastInsertId::Postgres(()))
}
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
DbPool::None => (0, LastInsertId::None),
})
}
pub(crate) async fn select(
&self,
_query: String,
_values: Vec<JsonValue>,
) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
Ok(match self {
#[cfg(feature = "sqlite")]
DbPool::Sqlite(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let rows = pool.fetch_all(query).await?;
let mut values = Vec::new();
for row in rows {
let mut value = IndexMap::default();
for (i, column) in row.columns().iter().enumerate() {
let v = row.try_get_raw(i)?;
let v = crate::decode::sqlite::to_json(v)?;
value.insert(column.name().to_string(), v);
}
values.push(value);
}
values
}
#[cfg(feature = "mysql")]
DbPool::MySql(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let rows = pool.fetch_all(query).await?;
let mut values = Vec::new();
for row in rows {
let mut value = IndexMap::default();
for (i, column) in row.columns().iter().enumerate() {
let v = row.try_get_raw(i)?;
let v = crate::decode::mysql::to_json(v)?;
value.insert(column.name().to_string(), v);
}
values.push(value);
}
values
}
#[cfg(feature = "postgres")]
DbPool::Postgres(pool) => {
let mut query = sqlx::query(&_query);
for value in _values {
if value.is_null() {
query = query.bind(None::<JsonValue>);
} else if value.is_string() {
query = query.bind(value.as_str().unwrap().to_owned())
} else if let Some(number) = value.as_number() {
query = query.bind(number.as_f64().unwrap_or_default())
} else {
query = query.bind(value);
}
}
let rows = pool.fetch_all(query).await?;
let mut values = Vec::new();
for row in rows {
let mut value = IndexMap::default();
for (i, column) in row.columns().iter().enumerate() {
let v = row.try_get_raw(i)?;
let v = crate::decode::postgres::to_json(v)?;
value.insert(column.name().to_string(), v);
}
values.push(value);
}
values
}
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
DbPool::None => Vec::new(),
})
}
}
#[cfg(feature = "sqlite")]
/// Maps the user-supplied DB connection string to a connection string
/// with a fully qualified file path to the app's designated "app_path"
fn path_mapper(mut app_path: std::path::PathBuf, connection_string: &str) -> String {
app_path.push(
connection_string
.split_once(':')
.expect("Couldn't parse the connection string for DB!")
.1,
);
format!(
"sqlite:{}",
app_path
.to_str()
.expect("Problem creating fully qualified path to Database file!")
)
}
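
For reference, `path_mapper` rewrites a relative `sqlite:` URL to an absolute path under the app's config dir. A trimmed, standalone re-creation of the transformation (the example path is illustrative; `display()` stands in for the original's `to_str().expect(…)`):

    fn path_mapper(mut app_path: std::path::PathBuf, connection_string: &str) -> String {
        // Keep everything after the scheme and anchor it under the app dir.
        app_path.push(connection_string.split_once(':').expect("no scheme in URL").1);
        format!("sqlite:{}", app_path.display())
    }

    fn main() {
        let mapped = path_mapper("/home/user/.config/my-app".into(), "sqlite:test.db");
        assert_eq!(mapped, "sqlite:/home/user/.config/my-app/test.db");
    }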