// Copyright 2019-2023 Tauri Programme within The Commons Conservancy // SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: MIT //! [![](https://github.com/tauri-apps/plugins-workspace/raw/v2/plugins/websocket/banner.png)](https://github.com/tauri-apps/plugins-workspace/tree/v2/plugins/websocket) //! //! Expose a WebSocket server to your Tauri frontend. #![doc( html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png", html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png" )] use futures_util::{stream::SplitSink, SinkExt, StreamExt}; use http::header::{HeaderName, HeaderValue}; use serde::{ser::Serializer, Deserialize, Serialize}; use tauri::{ ipc::Channel, plugin::{Builder as PluginBuilder, TauriPlugin}, Manager, Runtime, State, Window, }; use tokio::{net::TcpStream, sync::Mutex}; use tokio_tungstenite::{ connect_async_with_config, tungstenite::{ client::IntoClientRequest, protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, Message, }, MaybeTlsStream, WebSocketStream, }; use std::collections::HashMap; use std::str::FromStr; type Id = u32; type WebSocket = WebSocketStream>; type WebSocketWriter = SplitSink; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] enum Error { #[error(transparent)] Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), #[error(transparent)] InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), #[error(transparent)] InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { serializer.serialize_str(self.to_string().as_str()) } } #[derive(Default)] struct ConnectionManager(Mutex>); #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub struct ConnectionConfig { pub write_buffer_size: Option, pub max_write_buffer_size: Option, pub max_message_size: Option, pub max_frame_size: Option, #[serde(default)] pub accept_unmasked_frames: bool, pub headers: Option>, } impl From for WebSocketConfig { fn from(config: ConnectionConfig) -> Self { // Disabling the warning on max_send_queue which we don't use anymore since it was deprecated. #[allow(deprecated)] Self { max_send_queue: None, write_buffer_size: config.write_buffer_size.unwrap_or(128 * 1024), max_write_buffer_size: config.max_write_buffer_size.unwrap_or(usize::MAX), // This may be harmful since if it's not provided from js we're overwriting the default value with None, meaning no size limit. max_message_size: config.max_message_size, // This may be harmful since if it's not provided from js we're overwriting the default value with None, meaning no size limit. max_frame_size: config.max_frame_size, accept_unmasked_frames: config.accept_unmasked_frames, } } } #[derive(Deserialize, Serialize)] struct CloseFrame { pub code: u16, pub reason: String, } #[derive(Deserialize, Serialize)] #[serde(tag = "type", content = "data")] enum WebSocketMessage { Text(String), Binary(Vec), Ping(Vec), Pong(Vec), Close(Option), } #[tauri::command] async fn connect( window: Window, url: String, on_message: Channel, config: Option, ) -> Result { let id = rand::random(); let mut request = url.into_client_request()?; if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) { for (k, v) in headers { let header_name = HeaderName::from_str(k.as_str())?; let header_value = HeaderValue::from_str(v.as_str())?; request.headers_mut().insert(header_name, header_value); } } let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); let manager = window.state::(); manager.0.lock().await.insert(id, write); read.for_each(move |message| { let window_ = window.clone(); let on_message_ = on_message.clone(); async move { if let Ok(Message::Close(_)) = message { let manager = window_.state::(); manager.0.lock().await.remove(&id); } let response = match message { Ok(Message::Text(t)) => { serde_json::to_value(WebSocketMessage::Text(t)).unwrap() } Ok(Message::Binary(t)) => { serde_json::to_value(WebSocketMessage::Binary(t)).unwrap() } Ok(Message::Ping(t)) => { serde_json::to_value(WebSocketMessage::Ping(t)).unwrap() } Ok(Message::Pong(t)) => { serde_json::to_value(WebSocketMessage::Pong(t)).unwrap() } Ok(Message::Close(t)) => { serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame { code: v.code.into(), reason: v.reason.into_owned(), }))) .unwrap() } Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved. Err(e) => serde_json::to_value(Error::from(e)).unwrap(), }; let _ = on_message_.send(response); } }) .await; }); Ok(id) } #[tauri::command] async fn send( manager: State<'_, ConnectionManager>, id: Id, message: WebSocketMessage, ) -> Result<()> { if let Some(write) = manager.0.lock().await.get_mut(&id) { write .send(match message { WebSocketMessage::Text(t) => Message::Text(t), WebSocketMessage::Binary(t) => Message::Binary(t), WebSocketMessage::Ping(t) => Message::Ping(t), WebSocketMessage::Pong(t) => Message::Pong(t), WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame { code: v.code.into(), reason: std::borrow::Cow::Owned(v.reason), })), }) .await?; Ok(()) } else { Err(Error::ConnectionNotFound(id)) } } pub fn init() -> TauriPlugin { PluginBuilder::new("websocket") .js_init_script(include_str!("api-iife.js").to_string()) .invoke_handler(tauri::generate_handler![connect, send]) .setup(|app, _api| { app.manage(ConnectionManager::default()); Ok(()) }) .build() }