diff --git a/plugins/sql/guest-js/index.ts b/plugins/sql/guest-js/index.ts index 11d39e70..5a468c7c 100644 --- a/plugins/sql/guest-js/index.ts +++ b/plugins/sql/guest-js/index.ts @@ -18,6 +18,15 @@ export interface QueryResult { lastInsertId?: number } +export interface ConnectionOptions { + sqlite?: { + pool?: { + max_connections?: number + min_connections?: number + } + } +} + /** * **Database** * @@ -45,8 +54,12 @@ export default class Database { * const db = await Database.load("sqlite:test.db"); * ``` */ - static async load(path: string): Promise { + static async load( + path: string, + options?: ConnectionOptions + ): Promise { const _path = await invoke('plugin:sql|load', { + options, db: path }) diff --git a/plugins/sql/src/commands.rs b/plugins/sql/src/commands.rs index 760d00b2..5b40d230 100644 --- a/plugins/sql/src/commands.rs +++ b/plugins/sql/src/commands.rs @@ -7,7 +7,7 @@ use serde_json::Value as JsonValue; use sqlx::migrate::Migrator; use tauri::{command, AppHandle, Runtime, State}; -use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations}; +use crate::{wrapper::ConnectionOptions, DbInstances, DbPool, Error, LastInsertId, Migrations}; #[command] pub(crate) async fn load( @@ -15,8 +15,9 @@ pub(crate) async fn load( db_instances: State<'_, DbInstances>, migrations: State<'_, Migrations>, db: String, + options: Option, ) -> Result { - let pool = DbPool::connect(&db, &app).await?; + let pool = DbPool::connect(&db, &app, options).await?; if let Some(migrations) = migrations.0.lock().await.remove(&db) { let migrator = Migrator::new(migrations).await?; diff --git a/plugins/sql/src/lib.rs b/plugins/sql/src/lib.rs index 56b2a3a6..7c981ce4 100644 --- a/plugins/sql/src/lib.rs +++ b/plugins/sql/src/lib.rs @@ -150,7 +150,7 @@ impl Builder { let mut lock = instances.0.write().await; for db in config.preload { - let pool = DbPool::connect(&db, app).await?; + let pool = DbPool::connect(&db, app, None).await?; if let Some(migrations) = self.migrations.as_mut().and_then(|mm| mm.remove(&db)) diff --git a/plugins/sql/src/wrapper.rs b/plugins/sql/src/wrapper.rs index 54f124be..23fde475 100644 --- a/plugins/sql/src/wrapper.rs +++ b/plugins/sql/src/wrapper.rs @@ -33,6 +33,25 @@ pub enum DbPool { None, } +#[cfg(feature = "sqlite")] +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct SqlitePoolOptions { + max_connections: Option, + min_connections: Option, +} + +#[cfg(feature = "sqlite")] +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct SqliteOptions { + pub pool: Option, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct ConnectionOptions { + #[cfg(feature = "sqlite")] + pub sqlite: Option, +} + // public methods /* impl DbPool { /// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools. @@ -68,6 +87,7 @@ impl DbPool { pub(crate) async fn connect( conn_url: &str, _app: &AppHandle, + options: Option, ) -> Result { match conn_url .split_once(':') @@ -89,11 +109,23 @@ impl DbPool { Sqlite::create_database(conn_url).await?; } - let pool = sqlx::sqlite::SqlitePoolOptions::new() - .max_connections(1) - .connect(conn_url); + let mut pool_options = sqlx::sqlite::SqlitePoolOptions::new(); + + let sqlite_pool_options = options + .and_then(|opts| opts.sqlite) + .and_then(|sqlite_opts| sqlite_opts.pool); + + if let Some(custom_pool_opts) = sqlite_pool_options { + if let Some(max_connections) = custom_pool_opts.max_connections { + pool_options = pool_options.max_connections(max_connections); + } + + if let Some(min_connections) = custom_pool_opts.min_connections { + pool_options = pool_options.min_connections(min_connections); + } + } - Ok(Self::Sqlite(pool.await?)) + Ok(Self::Sqlite(pool_options.connect(conn_url).await?)) } #[cfg(feature = "mysql")] "mysql" => {