feat(sql): add support for SQLite pragma and encryption pragmas

- Introduced `libsqlite3-sys` dependency for SQLite support.
- Updated the `load` method to accept optional pragmas for database connections.
- Enhanced the JavaScript API to demonstrate loading databases with encryption keys and custom pragmas.
- Added VSCode settings for Rust analyzer to enable SQLite feature, to facilitate development.
- Updated Rust code to handle SQLite options and pragmas in the database connection logic.
pull/2553/head
Huakun Shen 4 months ago
parent 43f0f95da6
commit 2477559fff
No known key found for this signature in database

1
Cargo.lock generated

@ -6906,6 +6906,7 @@ version = "2.2.0"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"indexmap 2.7.0", "indexmap 2.7.0",
"libsqlite3-sys",
"log", "log",
"serde", "serde",
"serde_json", "serde_json",

@ -0,0 +1,3 @@
{
"rust-analyzer.cargo.features": ["sqlite"]
}

@ -36,8 +36,9 @@ time = "0.3"
tokio = { version = "1", features = ["sync"] } tokio = { version = "1", features = ["sync"] }
indexmap = { version = "2", features = ["serde"] } indexmap = { version = "2", features = ["serde"] }
libsqlite3-sys = { version = "0.30.1", features = ["bundled-sqlcipher"] }
[features] [features]
sqlite = ["sqlx/sqlite", "sqlx/runtime-tokio"] sqlite = ["sqlx/sqlite", "sqlx/runtime-tokio"]
mysql = ["sqlx/mysql", "sqlx/runtime-tokio-rustls"] mysql = ["sqlx/mysql", "sqlx/runtime-tokio-rustls"]
postgres = ["sqlx/postgres", "sqlx/runtime-tokio-rustls"] postgres = ["sqlx/postgres", "sqlx/runtime-tokio-rustls"]
# TODO: bundled-cipher etc

@ -1 +1 @@
if("__TAURI__"in window){var __TAURI_PLUGIN_SQL__=function(){"use strict";async function e(e,t={},s){return window.__TAURI_INTERNALS__.invoke(e,t,s)}"function"==typeof SuppressedError&&SuppressedError;class t{constructor(e){this.path=e}static async load(s){const n=await e("plugin:sql|load",{db:s});return new t(n)}static get(e){return new t(e)}async execute(t,s){const[n,r]=await e("plugin:sql|execute",{db:this.path,query:t,values:s??[]});return{lastInsertId:r,rowsAffected:n}}async select(t,s){return await e("plugin:sql|select",{db:this.path,query:t,values:s??[]})}async close(t){return await e("plugin:sql|close",{db:t})}}return t}();Object.defineProperty(window.__TAURI__,"sql",{value:__TAURI_PLUGIN_SQL__})} if("__TAURI__"in window){var __TAURI_PLUGIN_SQL__=function(){"use strict";async function e(e,t={},s){return window.__TAURI_INTERNALS__.invoke(e,t,s)}"function"==typeof SuppressedError&&SuppressedError;class t{constructor(e){this.path=e}static async load(s,n){const r=await e("plugin:sql|load",{db:s,pragmas:n?.pragmas});return new t(r)}static get(e){return new t(e)}async execute(t,s){const[n,r]=await e("plugin:sql|execute",{db:this.path,query:t,values:s??[]});return{lastInsertId:r,rowsAffected:n}}async select(t,s){return await e("plugin:sql|select",{db:this.path,query:t,values:s??[]})}async close(t){return await e("plugin:sql|close",{db:t})}}return t}();Object.defineProperty(window.__TAURI__,"sql",{value:__TAURI_PLUGIN_SQL__})}

@ -42,12 +42,29 @@ export default class Database {
* *
* @example * @example
* ```ts * ```ts
* // Basic connection
* const db = await Database.load("sqlite:test.db"); * const db = await Database.load("sqlite:test.db");
*
* // With encryption key
* const db = await Database.load("sqlite:encrypted.db", {
* pragmas: { "key": "encryption_key" }
* });
*
* // With pragmas
* const db = await Database.load("sqlite:test.db", {
* pragmas: { "journal_mode": "WAL", "foreign_keys": "ON" }
* });
* ``` * ```
*/ */
static async load(path: string): Promise<Database> { static async load(
path: string,
options?: {
pragmas?: Record<string, string>
}
): Promise<Database> {
const _path = await invoke<string>('plugin:sql|load', { const _path = await invoke<string>('plugin:sql|load', {
db: path db: path,
pragmas: options?.pragmas
}) })
return new Database(_path) return new Database(_path)

@ -7,8 +7,14 @@ use serde_json::Value as JsonValue;
use sqlx::migrate::Migrator; use sqlx::migrate::Migrator;
use tauri::{command, AppHandle, Runtime, State}; use tauri::{command, AppHandle, Runtime, State};
#[cfg(feature = "sqlite")]
use std::collections::HashMap;
use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations}; use crate::{DbInstances, DbPool, Error, LastInsertId, Migrations};
#[cfg(feature = "sqlite")]
use crate::SqliteOptions;
#[cfg(not(feature = "sqlite"))]
#[command] #[command]
pub(crate) async fn load<R: Runtime>( pub(crate) async fn load<R: Runtime>(
app: AppHandle<R>, app: AppHandle<R>,
@ -28,6 +34,40 @@ pub(crate) async fn load<R: Runtime>(
Ok(db) Ok(db)
} }
#[cfg(feature = "sqlite")]
#[command]
pub(crate) async fn load<R: Runtime>(
app: AppHandle<R>,
db_instances: State<'_, DbInstances>,
migrations: State<'_, Migrations>,
db: String,
pragmas: Option<HashMap<String, String>>,
) -> Result<String, crate::Error> {
let sqlite_options = if db.starts_with("sqlite:") {
let mut options = SqliteOptions::default();
// Apply pragmas if provided
if let Some(provided_pragmas) = pragmas {
options.pragmas.extend(provided_pragmas);
}
Some(options)
} else {
None
};
let pool = DbPool::connect(&db, &app, sqlite_options).await?;
if let Some(migrations) = migrations.0.lock().await.remove(&db) {
let migrator = Migrator::new(migrations).await?;
pool.migrate(&migrator).await?;
}
db_instances.0.write().await.insert(db.clone(), pool);
Ok(db)
}
/// Allows the database connection(s) to be closed; if no database /// Allows the database connection(s) to be closed; if no database
/// name is passed in then _all_ database connection pools will be /// name is passed in then _all_ database connection pools will be
/// shut down. /// shut down.
@ -78,3 +118,17 @@ pub(crate) async fn select(
let db = instances.get(&db).ok_or(Error::DatabaseNotLoaded(db))?; let db = instances.get(&db).ok_or(Error::DatabaseNotLoaded(db))?;
db.select(query, values).await db.select(query, values).await
} }
// #[command]
// pub(crate) async fn query(
// db_instances: State<'_, DbInstances>,
// db: String,
// query: String,
// values: Vec<JsonValue>,
// ) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
// let instances = db_instances.0.read().await;
// let db = instances.get(&db).ok_or(Error::DatabaseNotLoaded(db))?;
// db.
// // db.select(query, values).await
// }

@ -16,6 +16,8 @@ mod wrapper;
pub use error::Error; pub use error::Error;
pub use wrapper::DbPool; pub use wrapper::DbPool;
#[cfg(feature = "sqlite")]
pub use wrapper::SqliteOptions;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -23,6 +25,8 @@ use sqlx::{
error::BoxDynError, error::BoxDynError,
migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator}, migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator},
}; };
#[cfg(feature = "sqlite")]
use sqlx::sqlite::SqliteConnectOptions;
use tauri::{ use tauri::{
plugin::{Builder as PluginBuilder, TauriPlugin}, plugin::{Builder as PluginBuilder, TauriPlugin},
Manager, RunEvent, Runtime, Manager, RunEvent, Runtime,
@ -34,6 +38,9 @@ use std::collections::HashMap;
#[derive(Default)] #[derive(Default)]
pub struct DbInstances(pub RwLock<HashMap<String, DbPool>>); pub struct DbInstances(pub RwLock<HashMap<String, DbPool>>);
#[cfg(feature = "sqlite")]
struct SqlLiteOptionStore(Mutex<HashMap<String, SqliteConnectOptions>>);
#[derive(Serialize)] #[derive(Serialize)]
#[serde(untagged)] #[serde(untagged)]
pub(crate) enum LastInsertId { pub(crate) enum LastInsertId {
@ -137,6 +144,9 @@ impl Builder {
pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> { pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
PluginBuilder::<R, Option<PluginConfig>>::new("sql") PluginBuilder::<R, Option<PluginConfig>>::new("sql")
.invoke_handler(tauri::generate_handler![ .invoke_handler(tauri::generate_handler![
#[cfg(feature = "sqlite")]
commands::load,
#[cfg(not(feature = "sqlite"))]
commands::load, commands::load,
commands::execute, commands::execute,
commands::select, commands::select,
@ -150,6 +160,10 @@ impl Builder {
let mut lock = instances.0.write().await; let mut lock = instances.0.write().await;
for db in config.preload { for db in config.preload {
#[cfg(feature = "sqlite")]
let pool = DbPool::connect(&db, app, None).await?;
#[cfg(not(feature = "sqlite"))]
let pool = DbPool::connect(&db, app).await?; let pool = DbPool::connect(&db, app).await?;
if let Some(migrations) = if let Some(migrations) =

@ -4,9 +4,13 @@
#[cfg(feature = "sqlite")] #[cfg(feature = "sqlite")]
use std::fs::create_dir_all; use std::fs::create_dir_all;
#[cfg(feature = "sqlite")]
use std::collections::HashMap;
use indexmap::IndexMap; use indexmap::IndexMap;
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
#[cfg(feature = "sqlite")]
use sqlx::sqlite::SqliteConnectOptions;
#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))] #[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
use sqlx::{migrate::MigrateDatabase, Column, Executor, Pool, Row}; use sqlx::{migrate::MigrateDatabase, Column, Executor, Pool, Row};
#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))] #[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
@ -33,6 +37,20 @@ pub enum DbPool {
None, None,
} }
#[cfg(feature = "sqlite")]
pub struct SqliteOptions {
pub pragmas: HashMap<String, String>,
}
#[cfg(feature = "sqlite")]
impl Default for SqliteOptions {
fn default() -> Self {
Self {
pragmas: HashMap::new(),
}
}
}
// public methods // public methods
/* impl DbPool { /* impl DbPool {
/// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools. /// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools.
@ -68,6 +86,7 @@ impl DbPool {
pub(crate) async fn connect<R: Runtime>( pub(crate) async fn connect<R: Runtime>(
conn_url: &str, conn_url: &str,
_app: &AppHandle<R>, _app: &AppHandle<R>,
#[cfg(feature = "sqlite")] sqlite_options: Option<SqliteOptions>,
) -> Result<Self, crate::Error> { ) -> Result<Self, crate::Error> {
match conn_url match conn_url
.split_once(':') .split_once(':')
@ -82,13 +101,22 @@ impl DbPool {
.expect("No App config path was found!"); .expect("No App config path was found!");
create_dir_all(&app_path).expect("Couldn't create app config dir"); create_dir_all(&app_path).expect("Couldn't create app config dir");
let conn_url = &path_mapper(app_path, conn_url); let conn_url = &path_mapper(app_path, conn_url);
let filename = conn_url.split_once(':').unwrap().1;
if !Sqlite::database_exists(conn_url).await.unwrap_or(false) { let mut options = SqliteConnectOptions::new()
Sqlite::create_database(conn_url).await?; .filename(filename)
.create_if_missing(true);
// Apply pragmas if provided
if let Some(sqlite_opts) = sqlite_options {
for (pragma_name, pragma_value) in sqlite_opts.pragmas {
options = options.pragma(pragma_name, pragma_value);
}
} }
Ok(Self::Sqlite(Pool::connect(conn_url).await?))
// Connect with options (which includes create_if_missing)
Ok(Self::Sqlite(Pool::connect_with(options).await?))
} }
#[cfg(feature = "mysql")] #[cfg(feature = "mysql")]
"mysql" => { "mysql" => {

Loading…
Cancel
Save