diff --git a/plugins/window-state/src/lib.rs b/plugins/window-state/src/lib.rs index 1f4abdd6..a5cf5847 100644 --- a/plugins/window-state/src/lib.rs +++ b/plugins/window-state/src/lib.rs @@ -14,13 +14,14 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; use tauri::{ plugin::{Builder as PluginBuilder, TauriPlugin}, - Manager, Monitor, PhysicalPosition, PhysicalSize, RunEvent, Runtime, WebviewWindow, Window, - WindowEvent, + AppHandle, Manager, Monitor, PhysicalPosition, PhysicalSize, RunEvent, Runtime, WebviewWindow, + Window, WindowEvent, }; use std::{ collections::{HashMap, HashSet}, - fs::{create_dir_all, File}, + fs::create_dir_all, + io::BufReader, sync::{Arc, Mutex}, }; @@ -106,6 +107,7 @@ impl Default for WindowState { struct WindowStateCache(Arc>>); /// Used to prevent deadlocks from resize and position event listeners setting the cached state on restoring states struct RestoringWindowState(Mutex<()>); + pub trait AppHandleExt { /// Saves all open windows state to disk fn save_window_state(&self, flags: StateFlags) -> Result<()>; @@ -115,33 +117,31 @@ pub trait AppHandleExt { impl AppHandleExt for tauri::AppHandle { fn save_window_state(&self, flags: StateFlags) -> Result<()> { - if let Ok(app_dir) = self.path().app_config_dir() { - let plugin_state = self.state::(); - let state_path = app_dir.join(&plugin_state.filename); - let windows = self.webview_windows(); - let cache = self.state::(); - let mut state = cache.0.lock().unwrap(); - - for (label, s) in state.iter_mut() { - let window = match &plugin_state.map_label { - Some(map) => windows - .iter() - .find_map(|(l, window)| (map(l) == label).then_some(window)), - None => windows.get(label), - }; - - if let Some(window) = window { - window.update_state(s, flags)?; - } + let app_dir = self.path().app_config_dir()?; + let plugin_state = self.state::(); + let state_path = app_dir.join(&plugin_state.filename); + let windows = self.webview_windows(); + let cache = self.state::(); + let mut state = cache.0.lock().unwrap(); + + for (label, s) in state.iter_mut() { + let window = if let Some(map) = &plugin_state.map_label { + windows + .iter() + .find_map(|(l, window)| (map(l) == label).then_some(window)) + } else { + windows.get(label) + }; + + if let Some(window) = window { + window.update_state(s, flags)?; } - - create_dir_all(&app_dir) - .map_err(Error::Io) - .and_then(|_| File::create(state_path).map_err(Into::into)) - .and_then(|mut f| serde_json::to_writer_pretty(&mut f, &*state).map_err(Into::into)) - } else { - Ok(()) } + + create_dir_all(app_dir)?; + std::fs::write(state_path, serde_json::to_vec_pretty(&*state)?)?; + + Ok(()) } fn filename(&self) -> String { @@ -159,6 +159,7 @@ impl WindowExt for WebviewWindow { self.as_ref().window().restore_state(flags) } } + impl WindowExt for Window { fn restore_state(&self, flags: StateFlags) -> tauri::Result<()> { let plugin_state = self.app_handle().state::(); @@ -352,7 +353,7 @@ impl Builder { self } - /// Sets a filter callback to exclude specific windows from being tracked. + /// Sets a filter callback to exclude specific windows from being tracked. /// Return `true` to save the state, or `false` to skip and not save it. pub fn with_filter(mut self, filter_callback: F) -> Self where @@ -391,25 +392,8 @@ impl Builder { cmd::filename ]) .setup(|app, _api| { - let cache: Arc>> = - if let Ok(app_dir) = app.path().app_config_dir() { - let state_path = app_dir.join(&filename); - if state_path.exists() { - Arc::new(Mutex::new( - std::fs::read(state_path) - .map_err(Error::from) - .and_then(|state| { - serde_json::from_slice(&state).map_err(Into::into) - }) - .unwrap_or_default(), - )) - } else { - Default::default() - } - } else { - Default::default() - }; - app.manage(WindowStateCache(cache)); + let cache = load_saved_window_states(app, &filename).unwrap_or_default(); + app.manage(WindowStateCache(Arc::new(Mutex::new(cache)))); app.manage(RestoringWindowState(Mutex::new(()))); app.manage(PluginState { filename, @@ -522,6 +506,18 @@ impl Builder { } } +fn load_saved_window_states( + app: &AppHandle, + filename: &String, +) -> Result> { + let app_dir = app.path().app_config_dir()?; + let state_path = app_dir.join(filename); + let file = std::fs::File::open(state_path)?; + let reader = BufReader::new(file); + let states = serde_json::from_reader(reader)?; + Ok(states) +} + trait MonitorExt { fn intersects(&self, position: PhysicalPosition, size: PhysicalSize) -> bool; }