// Copyright 2019-2023 Tauri Programme within The Commons Conservancy // SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: MIT //! Open a WebSocket connection using a Rust client in JS. #![doc( html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png", html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png" )] use futures_util::{stream::SplitSink, SinkExt, StreamExt}; use http::header::{HeaderName, HeaderValue}; use serde::{ser::Serializer, Deserialize, Serialize}; use tauri::{ ipc::Channel, plugin::{Builder as PluginBuilder, TauriPlugin}, Manager, Runtime, State, Window, }; use tokio::{net::TcpStream, sync::Mutex}; #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] use tokio_tungstenite::connect_async_tls_with_config; #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))] use tokio_tungstenite::connect_async_with_config; use tokio_tungstenite::{ tungstenite::{ client::IntoClientRequest, protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, Message, }, Connector, MaybeTlsStream, WebSocketStream, }; use std::collections::HashMap; use std::str::FromStr; type Id = u32; type WebSocket = WebSocketStream>; type WebSocketWriter = SplitSink; type Result = std::result::Result; #[derive(Debug, thiserror::Error)] enum Error { #[error(transparent)] Websocket(#[from] tokio_tungstenite::tungstenite::Error), #[error("connection not found for the given id: {0}")] ConnectionNotFound(Id), #[error(transparent)] InvalidHeaderValue(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue), #[error(transparent)] InvalidHeaderName(#[from] tokio_tungstenite::tungstenite::http::header::InvalidHeaderName), } impl Serialize for Error { fn serialize(&self, serializer: S) -> std::result::Result where S: Serializer, { serializer.serialize_str(self.to_string().as_str()) } } #[derive(Default)] struct ConnectionManager(Mutex>); #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] struct TlsConnector(Mutex>); #[derive(Deserialize)] #[serde(untagged, rename_all = "camelCase")] enum Max { None, Number(usize), } #[derive(Deserialize)] #[serde(rename_all = "camelCase")] pub(crate) struct ConnectionConfig { pub read_buffer_size: Option, pub write_buffer_size: Option, pub max_write_buffer_size: Option, pub max_message_size: Option, pub max_frame_size: Option, #[serde(default)] pub accept_unmasked_frames: bool, pub headers: Option>, } impl From for WebSocketConfig { fn from(config: ConnectionConfig) -> Self { let mut builder = WebSocketConfig::default().accept_unmasked_frames(config.accept_unmasked_frames); if let Some(read_buffer_size) = config.read_buffer_size { builder = builder.read_buffer_size(read_buffer_size) } if let Some(write_buffer_size) = config.write_buffer_size { builder = builder.write_buffer_size(write_buffer_size) } if let Some(max_write_buffer_size) = config.max_write_buffer_size { builder = builder.max_write_buffer_size(max_write_buffer_size) } if let Some(max_message_size) = config.max_message_size { let max_size = match max_message_size { Max::None => Option::None, Max::Number(n) => Some(n), }; builder = builder.max_message_size(max_size); } if let Some(max_frame_size) = config.max_frame_size { let max_size = match max_frame_size { Max::None => Option::None, Max::Number(n) => Some(n), }; builder = builder.max_frame_size(max_size); } builder } } #[derive(Deserialize, Serialize)] struct CloseFrame { pub code: u16, pub reason: String, } #[derive(Deserialize, Serialize)] #[serde(tag = "type", content = "data")] enum WebSocketMessage { Text(String), Binary(Vec), Ping(Vec), Pong(Vec), Close(Option), } #[tauri::command] async fn connect( window: Window, url: String, on_message: Channel, config: Option, ) -> Result { let id = rand::random(); let mut request = url.into_client_request()?; if let Some(headers) = config.as_ref().and_then(|c| c.headers.as_ref()) { for (k, v) in headers { let header_name = HeaderName::from_str(k.as_str())?; let header_value = HeaderValue::from_str(v.as_str())?; request.headers_mut().insert(header_name, header_value); } } #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] let tls_connector = match window.try_state::() { Some(tls_connector) => tls_connector.0.lock().await.clone(), None => None, }; #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] let (ws_stream, _) = connect_async_tls_with_config(request, config.map(Into::into), false, tls_connector) .await?; #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))] let (ws_stream, _) = connect_async_with_config(request, config.map(Into::into), false).await?; tauri::async_runtime::spawn(async move { let (write, read) = ws_stream.split(); let manager = window.state::(); manager.0.lock().await.insert(id, write); read.for_each(move |message| { let window_ = window.clone(); let on_message_ = on_message.clone(); async move { if let Ok(Message::Close(_)) = message { let manager = window_.state::(); manager.0.lock().await.remove(&id); } let response = match message { Ok(Message::Text(t)) => { serde_json::to_value(WebSocketMessage::Text(t.to_string())).unwrap() } Ok(Message::Binary(t)) => { serde_json::to_value(WebSocketMessage::Binary(t.to_vec())).unwrap() } Ok(Message::Ping(t)) => { serde_json::to_value(WebSocketMessage::Ping(t.to_vec())).unwrap() } Ok(Message::Pong(t)) => { serde_json::to_value(WebSocketMessage::Pong(t.to_vec())).unwrap() } Ok(Message::Close(t)) => { serde_json::to_value(WebSocketMessage::Close(t.map(|v| CloseFrame { code: v.code.into(), reason: v.reason.to_string(), }))) .unwrap() } Ok(Message::Frame(_)) => serde_json::Value::Null, // This value can't be recieved. Err(e) => serde_json::to_value(Error::from(e)).unwrap(), }; let _ = on_message_.send(response); } }) .await; }); Ok(id) } #[tauri::command] async fn send( manager: State<'_, ConnectionManager>, id: Id, message: WebSocketMessage, ) -> Result<()> { if let Some(write) = manager.0.lock().await.get_mut(&id) { write .send(match message { WebSocketMessage::Text(t) => Message::Text(t.into()), WebSocketMessage::Binary(t) => Message::Binary(t.into()), WebSocketMessage::Ping(t) => Message::Ping(t.into()), WebSocketMessage::Pong(t) => Message::Pong(t.into()), WebSocketMessage::Close(t) => Message::Close(t.map(|v| ProtocolCloseFrame { code: v.code.into(), reason: v.reason.into(), })), }) .await?; Ok(()) } else { Err(Error::ConnectionNotFound(id)) } } pub fn init() -> TauriPlugin { Builder::default().build() } #[derive(Default)] pub struct Builder { tls_connector: Option, } impl Builder { pub fn new() -> Self { Self { tls_connector: None, } } pub fn tls_connector(mut self, connector: Connector) -> Self { self.tls_connector.replace(connector); self } pub fn build(self) -> TauriPlugin { PluginBuilder::new("websocket") .invoke_handler(tauri::generate_handler![connect, send]) .setup(|app, _api| { app.manage(ConnectionManager::default()); #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] app.manage(TlsConnector(Mutex::new(self.tls_connector))); Ok(()) }) .build() } }