use futures_util::{stream::SplitSink, SinkExt, StreamExt}; use serde::{ser::Serializer, Deserialize, Serialize}; use tauri::{ api::ipc::{format_callback, CallbackFn}, plugin::{Builder as PluginBuilder, TauriPlugin}, Manager, Runtime, State, Window, }; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{ connect_async_with_config, tungstenite::{ protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, Message, }, MaybeTlsStream, WebSocketStream, }; use std::collections::HashMap; type Id = u32; type WebSocket = WebSocketStream>; type WebSocketWriter = SplitSink; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] enum Error { #[error(transparent)] Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), } impl Serialize for Error { fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { serializer.serialize_str(self.to_string().as_str()) } } #[derive(Default)] struct ConnectionManager(Mutex>); #[derive(Default, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectionConfig { pub max_send_queue: Option, pub max_message_size: Option, pub max_frame_size: Option, pub accept_unmasked_frames: bool, } impl From for WebSocketConfig { fn from(config: ConnectionConfig) -> Self { Self { max_send_queue: config.max_send_queue, max_message_size: config.max_message_size, max_frame_size: config.max_frame_size, accept_unmasked_frames: config.accept_unmasked_frames, } } } #[derive(Deserialize, Serialize)] struct CloseFrame { pub code: u16, pub reason: String, } #[derive(Deserialize, Serialize)] #[serde(tag = "type", content = "data")] enum WebSocketMessage { Text(String), Binary(Vec), Ping(Vec), Pong(Vec), Close(Option), } #[tauri::command] fn connect( window: Window, url: String, callback_function: CallbackFn, config: Option, ) -> Result { let id = rand::random(); let (ws_stream, _) = tauri::async_runtime::block_on(connect_async_with_config(url, config.map(Into::into)))?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); let manager = window.state::(); manager.0.lock().await.insert(id, write); read.for_each(move |message| { let window_ = window.clone(); async move { if let Ok(Message::Close(_)) = message { let manager = window_.state::(); manager.0.lock().await.remove(&id); } let response = match message { Ok(Message::Text(t)) => { serde_json::to_value(WebSocketMessage::Text(t)).unwrap() } Ok(Message::Binary(t)) => { serde_json::to_value(WebSocketMessage::Binary(t)).unwrap() } Ok(Message::Ping(t)) => { serde_json::to_value(WebSocketMessage::Ping(t)).unwrap() } Ok(Message::Pong(t)) => { serde_json::to_value(WebSocketMessage::Pong(t)).unwrap() } Ok(Message::Close(t)) => { serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame { code: v.code.into(), reason: v.reason.into_owned(), }))) .unwrap() } Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved. Err(e) => serde_json::to_value(Error::from(e)).unwrap(), }; let js = format_callback(callback_function, &response) .expect("unable to serialize websocket message"); let _ = window_.eval(js.as_str()); } }) .await; }); Ok(id) } #[tauri::command] async fn send( manager: State<'_, ConnectionManager>, id: Id, message: WebSocketMessage, ) -> Result<()> { if let Some(write) = manager.0.lock().await.get_mut(&id) { write .send(match message { WebSocketMessage::Text(t) => Message::Text(t), WebSocketMessage::Binary(t) => Message::Binary(t), WebSocketMessage::Ping(t) => Message::Ping(t), WebSocketMessage::Pong(t) => Message::Pong(t), WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame { code: v.code.into(), reason: std::borrow::Cow::Owned(v.reason), })), }) .await?; Ok(()) } else { Err(Error::ConnectionNotFound(id)) } } pub fn init() -> TauriPlugin { PluginBuilder::new("websocket") .invoke_handler(tauri::generate_handler![connect, send]) .setup(|app| { app.manage(ConnectionManager::default()); Ok(()) }) .build() }